
Commit

webds works, try cluster
mwalmsley committed Nov 3, 2023
1 parent 257b4dc commit 6a375e3
Showing 6 changed files with 432 additions and 31 deletions.
4 changes: 3 additions & 1 deletion .gitignore
@@ -165,4 +165,6 @@ results

hparams.yaml

data/pretrained_models
data/pretrained_models

*.tar
1 change: 1 addition & 0 deletions only_for_me/narval/finetune.sh
@@ -6,6 +6,7 @@
#SBATCH --cpus-per-task=12
#SBATCH --gres=gpu:a100:2

# https://github.com/webdataset/webdataset-lightning/blob/main/simple_cluster.py
#### SBATCH --mem=32G
#### SBATCH --nodes=1
#### SBATCH --time=0:20:0
120 changes: 120 additions & 0 deletions only_for_me/narval/gz_decals_webdataset.py
@@ -0,0 +1,120 @@
import logging
import os
import shutil
import sys
import cv2
import json
from itertools import islice
import glob

import tqdm
import numpy as np
import pandas as pd
from PIL import Image  # imported explicitly: webdataset assumes PIL.Image is available and raises otherwise

from galaxy_datasets.shared import label_metadata
from galaxy_datasets import gz_decals_5
from galaxy_datasets.transforms import default_transforms
from galaxy_datasets.pytorch import galaxy_dataset

import webdataset as wds

def galaxy_to_wds(galaxy: pd.Series, label_cols):
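# convert one catalog row into a WebDataset sample dict: the sample key is the galaxy id,
# the RGB image goes under 'image.jpg' and the label columns under 'labels.json'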

im = cv2.imread(galaxy['file_loc'])
# cv2 loads images as BGR by historical convention; convert to RGB
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
labels = json.dumps(galaxy[label_cols].to_dict())
id_str = str(galaxy['id_str'])
# print(id_str)
return {
"__key__": id_str,
"image.jpg": im,
"labels.json": labels
}

def df_to_wds(df: pd.DataFrame, label_cols, save_loc, n_shards):
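# split the catalog into n_shards roughly equal pieces and write each piece to its own .tar shard;
# dots are stripped from id_str first because WebDataset splits sample keys on the first '.'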
df['id_str'] = df['id_str'].str.replace('.', '_', regex=False)

shard_dfs = np.array_split(df, n_shards)
print('shards: ', len(shard_dfs))
print('shard size: ', len(shard_dfs[0]))
for shard_n, shard_df in tqdm.tqdm(enumerate(shard_dfs), total=len(shard_dfs)):
shard_save_loc = save_loc.replace('.tar', f'_{shard_n}_{len(shard_df)}.tar')
print(shard_save_loc)
sink = wds.TarWriter(shard_save_loc)
for index, galaxy in shard_df.iterrows():
sink.write(galaxy_to_wds(galaxy, label_cols))
sink.close()

def check_wds(wds_loc):
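# quick sanity check: read the first few samples back from a shard and print their contents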

dataset = wds.WebDataset(wds_loc) \
.decode("rgb")

for sample in islice(dataset, 0, 3):
print(sample['__key__'])
print(sample['image.jpg'].shape) # .decode('rgb') converts the 0-255 uint8 image to an RGB float array scaled to 0-1
print(type(sample['labels.json'])) # automatically decoded

def identity(x):
# a named function rather than a lambda, so it can be pickled by dataloader workers
return x


def load_wds(wds_loc):
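# as check_wds, but also apply the default galaxy-datasets augmentations, as the training pipeline would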

augmentation_transform = default_transforms() # A.Compose object
def do_transform(img):
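# apply the albumentations transform, then move channels first (HWC -> CHW) and cast to float32 for PyTorch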
return np.transpose(augmentation_transform(image=np.array(img))["image"], axes=[2, 0, 1]).astype(np.float32)

dataset = wds.WebDataset(wds_loc) \
.decode("rgb") \
.to_tuple('image.jpg', 'labels.json') \
.map_tuple(do_transform, identity)

for sample in islice(dataset, 0, 3):
print(sample[0].shape)
print(sample[1])


def main():
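# workflow: write the train catalog to webdataset shards (df_to_wds), optionally sanity-check them
# (check_wds, load_wds), then smoke-test the WebDataModule that training will use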

train_catalog, _ = gz_decals_5(root='/home/walml/repos/zoobot/only_for_me/narval/temp', download=False, train=True)

train_catalog = train_catalog[:512*64]
label_cols = label_metadata.decals_dr5_ortho_label_cols

save_loc = "gz_decals_5_train.tar"

# df_to_wds(train_catalog, label_cols, save_loc, n_shards=8)

# check_wds(save_loc)

# load_wds(save_loc)

import zoobot.pytorch.training.webdatamodule as webdatamodule

wdm = webdatamodule.WebDataModule(
train_urls=glob.glob(save_loc.replace('.tar', '_*.tar')),
val_urls=[],
# train_size=len(train_catalog),
# val_size=0,
label_cols=label_cols,
num_workers=1
)
wdm.setup('fit')
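# iterate a few batches to confirm the images have the expected shape and the labels decode correctly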

for sample in islice(wdm.train_dataloader(), 0, 3):
images, labels = sample
print(images.shape)
# print(len(labels)) # list of dicts
print(labels)
exit()



if __name__ == '__main__':

main()

130 changes: 130 additions & 0 deletions only_for_me/narval/train.py
@@ -0,0 +1,130 @@
import logging
import os
import argparse
import glob

from pytorch_lightning.loggers import WandbLogger
import wandb

from zoobot.pytorch.training import train_with_pytorch_lightning
from zoobot.shared import benchmark_datasets, schemas


if __name__ == '__main__':

"""
Used to create the PyTorch pretrained weights checkpoints
See .sh file of the same name for args used.
See zoobot/pytorch/examples/minimal_examples.py for a friendlier example
"""
parser = argparse.ArgumentParser()
parser.add_argument('--save-dir', dest='save_dir', type=str)
# parser.add_argument('--data-dir', dest='data_dir', type=str)
# parser.add_argument('--dataset', dest='dataset', type=str, help='dataset to use, either "gz_decals_dr5" or "gz_evo"')
parser.add_argument('--architecture', dest='architecture_name', default='efficientnet', type=str)
parser.add_argument('--resize-after-crop', dest='resize_after_crop',
type=int, default=224)
parser.add_argument('--color', default=False, action='store_true')
parser.add_argument('--batch-size', dest='batch_size',
default=256, type=int)
parser.add_argument('--gpus', dest='gpus', default=1, type=int)
parser.add_argument('--nodes', dest='nodes', default=1, type=int)
parser.add_argument('--mixed-precision', dest='mixed_precision',
default=False, action='store_true')
parser.add_argument('--debug', dest='debug',
default=False, action='store_true')
parser.add_argument('--wandb', dest='wandb',
default=False, action='store_true')
parser.add_argument('--seed', dest='random_state', default=42, type=int)
args = parser.parse_args()

"""
debug
python only_for_me/narval/train.py --save-dir only_for_me/narval/debug_models --batch-size 32 --color
"""

logging.basicConfig(level=logging.INFO)

random_state = args.random_state

# if args.nodes > 1:
# # at Manchester, our slurm cluster sets TASKS not NTASKS, which then confuses lightning
# if 'SLURM_NTASKS_PER_NODE' not in os.environ.keys():
# os.environ['SLURM_NTASKS_PER_NODE'] = os.environ['SLURM_TASKS_PER_NODE']
# # log the rest to help debug
# logging.info([(x, y) for (x, y) in os.environ.items() if 'SLURM' in x])

if args.debug:
download = False
else:
# download = True # for first use
download = False # for speed afterwards

if os.path.isdir('/home/walml/repos/zoobot'):
search_str = '/home/walml/repos/zoobot/gz_decals_5_train_*.tar'

else:
search_str = '/home/walml/projects/def-bovy/walml/data/webdatasets/gz_decals_5/gz_decals_5_train_*.tar'

all_urls = glob.glob(search_str)
assert len(all_urls) > 0, search_str
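# crude split for this experiment: the first 6 shards go to training, the remainder to validation (no test shards)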
train_urls, val_urls = all_urls[:6], all_urls[6:]
schema = schemas.decals_dr5_ortho_schema


# if args.dataset == 'gz_decals_dr5':
# schema, (train_catalog, val_catalog, test_catalog) = benchmark_datasets.get_gz_decals_dr5_benchmark_dataset(args.data_dir, random_state, download=download)
# elif args.dataset == 'gz_evo':
# schema, (train_catalog, val_catalog, test_catalog) = benchmark_datasets.get_gz_evo_benchmark_dataset(args.data_dir, random_state, download=download)
# else:
# raise ValueError(f'Dataset {args.dataset} not recognised: should be "gz_decals_dr5" or "gz_evo"')


# logging.info('First val galaxy: {}'.format(val_catalog.iloc[0]['id_str']))


# debug mode
if args.debug:
logging.warning(
'Using debug mode: cutting urls down to 2')
train_urls = train_urls[:2]
val_urls = val_urls[:2]
epochs = 2
else:
epochs = 1000

if args.wandb:
wandb_logger = WandbLogger(
project='narval',
name=os.path.basename(args.save_dir),
log_model=False
)
else:
wandb_logger = None

train_with_pytorch_lightning.train_default_zoobot_from_scratch(
save_dir=args.save_dir,
schema=schema,
train_urls = train_urls,
val_urls = val_urls,
test_urls = None,
architecture_name=args.architecture_name,
batch_size=args.batch_size,
epochs=epochs, # rely on early stopping
patience=10,
# augmentation parameters
color=args.color,
resize_after_crop=args.resize_after_crop,
# hardware parameters
gpus=args.gpus,
nodes=args.nodes,
mixed_precision=args.mixed_precision,
wandb_logger=wandb_logger,
prefetch_factor=4,
num_workers=11, # system has 24 cpu, 12 cpu per gpu, leave a little wiggle room
random_state=random_state,
learning_rate=1e-3,
)

wandb.finish()
78 changes: 48 additions & 30 deletions zoobot/pytorch/training/train_with_pytorch_lightning.py
@@ -11,6 +11,7 @@
from galaxy_datasets.pytorch.galaxy_datamodule import GalaxyDataModule

from zoobot.pytorch.estimators import define_model
from zoobot.pytorch.training import webdatamodule


def train_default_zoobot_from_scratch(
@@ -22,6 +23,9 @@ def train_default_zoobot_from_scratch(
train_catalog=None,
val_catalog=None,
test_catalog=None,
train_urls=None,
val_urls=None,
test_urls=None,
# training time parameters
epochs=1000,
patience=8,
@@ -167,22 +171,6 @@
Suggest reducing num_workers."""
)


if catalog is not None:
assert train_catalog is None
assert val_catalog is None
assert test_catalog is None
catalogs_to_use = {
'catalog': catalog
}
else:
assert catalog is None
catalogs_to_use = {
'train_catalog': train_catalog,
'val_catalog': val_catalog,
'test_catalog': test_catalog # may be None
}

if wandb_logger is not None:
wandb_logger.log_hyperparams({
'random_state': random_state,
@@ -201,20 +189,50 @@
'framework': 'pytorch'
})

datamodule = GalaxyDataModule(
label_cols=schema.label_cols,
# can take either a catalog (and split it), or a pre-split catalog
**catalogs_to_use,
# augmentations parameters
greyscale=not color,
crop_scale_bounds=crop_scale_bounds,
crop_ratio_bounds=crop_ratio_bounds,
resize_after_crop=resize_after_crop,
# hardware parameters
batch_size=batch_size, # on 2xA100s, 256 with DDP, 512 with distributed (i.e. split batch)
num_workers=num_workers,
prefetch_factor=prefetch_factor
)
# work out what dataset the user has passed
single_catalog = catalog is not None
split_catalogs = train_catalog is not None
webdatasets = train_urls is not None
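# exactly one input style should be provided: a single catalog to split, pre-split catalogs, or premade webdataset shard urls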

if single_catalog or split_catalogs:
# this branch will use GalaxyDataModule to load catalogs
assert not webdatasets
if single_catalog:
assert not split_catalogs
data_to_use = {
'catalog': catalog
}
else:
data_to_use = {
'train_catalog': train_catalog,
'val_catalog': val_catalog,
'test_catalog': test_catalog # may be None
}
datamodule = GalaxyDataModule(
label_cols=schema.label_cols,
# can take either a catalog (and split it), or a pre-split catalog
**data_to_use,
# augmentations parameters
greyscale=not color,
crop_scale_bounds=crop_scale_bounds,
crop_ratio_bounds=crop_ratio_bounds,
resize_after_crop=resize_after_crop,
# hardware parameters
batch_size=batch_size, # on 2xA100s, 256 with DDP, 512 with distributed (i.e. split batch)
num_workers=num_workers,
prefetch_factor=prefetch_factor
)
else:
# this branch will use WebDataModule to load premade webdatasets
datamodule = webdatamodule.WebDataModule(
train_urls=train_urls,
val_urls=val_urls,
batch_size=batch_size,
num_workers=num_workers,
label_cols=schema.label_cols
# TODO pass through the rest
)

datamodule.setup(stage='fit')

# these args are automatically logged