Showing 6 changed files with 432 additions and 31 deletions.
@@ -165,4 +165,6 @@ results
hparams.yaml
data/pretrained_models
*.tar
@@ -0,0 +1,120 @@
import logging
import os
import shutil
import sys
import cv2
import json
from itertools import islice
import glob

import tqdm
import numpy as np
import pandas as pd
from PIL import Image  # imported so webdataset can decode images without a PIL.Image error

from galaxy_datasets.shared import label_metadata
from galaxy_datasets import gz_decals_5
from galaxy_datasets.transforms import default_transforms
from galaxy_datasets.pytorch import galaxy_dataset

import webdataset as wds

def galaxy_to_wds(galaxy: pd.Series, label_cols):

    im = cv2.imread(galaxy['file_loc'])
    # cv2 loads images as BGR for historical reasons; convert to RGB
    im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
    labels = json.dumps(galaxy[label_cols].to_dict())
    id_str = str(galaxy['id_str'])
    # print(id_str)
    return {
        "__key__": id_str,
        "image.jpg": im,
        "labels.json": labels
    }

def df_to_wds(df: pd.DataFrame, label_cols, save_loc, n_shards):
    # webdataset treats '.' in keys as an extension separator, so replace it
    df['id_str'] = df['id_str'].str.replace('.', '_', regex=False)

    shard_dfs = np.array_split(df, n_shards)
    print('shards: ', len(shard_dfs))
    print('shard size: ', len(shard_dfs[0]))
    for shard_n, shard_df in tqdm.tqdm(enumerate(shard_dfs), total=len(shard_dfs)):
        shard_save_loc = save_loc.replace('.tar', f'_{shard_n}_{len(shard_df)}.tar')
        print(shard_save_loc)
        sink = wds.TarWriter(shard_save_loc)
        for index, galaxy in shard_df.iterrows():
            sink.write(galaxy_to_wds(galaxy, label_cols))
        sink.close()

def check_wds(wds_loc):

    dataset = wds.WebDataset(wds_loc) \
        .decode("rgb")

    for sample in islice(dataset, 0, 3):
        print(sample['__key__'])
        print(sample['image.jpg'].shape)  # .decode("rgb") converts the 0-255 uint8 jpg to a 0-1 float RGB array
        print(type(sample['labels.json']))  # json is automatically decoded

def identity(x):
    # a named function rather than a lambda, so it remains pickleable
    return x

def load_wds(wds_loc):

    augmentation_transform = default_transforms()  # A.Compose object

    def do_transform(img):
        # apply augmentations, then transpose HWC -> CHW for PyTorch
        return np.transpose(augmentation_transform(image=np.array(img))["image"], axes=[2, 0, 1]).astype(np.float32)

    dataset = wds.WebDataset(wds_loc) \
        .decode("rgb") \
        .to_tuple('image.jpg', 'labels.json') \
        .map_tuple(do_transform, identity)

    for sample in islice(dataset, 0, 3):
        print(sample[0].shape)
        print(sample[1])

def main():

    train_catalog, _ = gz_decals_5(root='/home/walml/repos/zoobot/only_for_me/narval/temp', download=False, train=True)

    train_catalog = train_catalog[:512*64]
    label_cols = label_metadata.decals_dr5_ortho_label_cols

    save_loc = "gz_decals_5_train.tar"

    # df_to_wds(train_catalog, label_cols, save_loc, n_shards=8)

    # check_wds(save_loc)

    # load_wds(save_loc)

    import zoobot.pytorch.training.webdatamodule as webdatamodule

    wdm = webdatamodule.WebDataModule(
        train_urls=glob.glob(save_loc.replace('.tar', '_*.tar')),
        val_urls=[],
        # train_size=len(train_catalog),
        # val_size=0,
        label_cols=label_cols,
        num_workers=1
    )
    wdm.setup('fit')

    for sample in islice(wdm.train_dataloader(), 0, 3):
        images, labels = sample
        print(images.shape)
        # print(len(labels))  # list of dicts
        print(labels)
        exit()


if __name__ == '__main__':
    main()
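
For reference, here is a minimal sketch (not part of this commit) of reading the shards written above in batches with webdataset's own WebLoader rather than the zoobot WebDataModule; the shard-name pattern, shuffle buffer, and batch size are illustrative assumptions:

import glob
from itertools import islice

import webdataset as wds

urls = glob.glob('gz_decals_5_train_*.tar')  # shards written by df_to_wds above
dataset = (
    wds.WebDataset(urls)
    .shuffle(1000)  # shuffle samples within a streaming buffer
    .decode('rgb')  # decode jpg bytes to 0-1 float RGB arrays
    .to_tuple('image.jpg', 'labels.json')
    .batched(32)  # collate consecutive samples into batches
)
loader = wds.WebLoader(dataset, batch_size=None, num_workers=2)

for images, labels in islice(loader, 0, 1):
    print(images.shape)  # (32, H, W, 3): still channels-last until transposed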
@@ -0,0 +1,130 @@
import logging
import os
import argparse
import glob

from pytorch_lightning.loggers import WandbLogger
import wandb

from zoobot.pytorch.training import train_with_pytorch_lightning
from zoobot.shared import benchmark_datasets, schemas

if __name__ == '__main__':

    """
    Used to create the PyTorch pretrained weights checkpoints.
    See the .sh file of the same name for the args used.
    See zoobot/pytorch/examples/minimal_examples.py for a friendlier example.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--save-dir', dest='save_dir', type=str)
    # parser.add_argument('--data-dir', dest='data_dir', type=str)
    # parser.add_argument('--dataset', dest='dataset', type=str, help='dataset to use, either "gz_decals_dr5" or "gz_evo"')
    parser.add_argument('--architecture', dest='architecture_name', default='efficientnet', type=str)
    parser.add_argument('--resize-after-crop', dest='resize_after_crop', type=int, default=224)
    parser.add_argument('--color', default=False, action='store_true')
    parser.add_argument('--batch-size', dest='batch_size', default=256, type=int)
    parser.add_argument('--gpus', dest='gpus', default=1, type=int)
    parser.add_argument('--nodes', dest='nodes', default=1, type=int)
    parser.add_argument('--mixed-precision', dest='mixed_precision', default=False, action='store_true')
    parser.add_argument('--debug', dest='debug', default=False, action='store_true')
    parser.add_argument('--wandb', dest='wandb', default=False, action='store_true')
    parser.add_argument('--seed', dest='random_state', default=42, type=int)
    args = parser.parse_args()

""" | ||
debug | ||
python only_for_me/narval/train.py --save-dir only_for_me/narval/debug_models --batch-size 32 --color | ||
""" | ||
|
||
logging.basicConfig(level=logging.INFO) | ||
|
||
random_state = args.random_state | ||
|
||
# if args.nodes > 1: | ||
# # at Manchester, our slurm cluster sets TASKS not NTASKS, which then confuses lightning | ||
# if 'SLURM_NTASKS_PER_NODE' not in os.environ.keys(): | ||
# os.environ['SLURM_NTASKS_PER_NODE'] = os.environ['SLURM_TASKS_PER_NODE'] | ||
# # log the rest to help debug | ||
# logging.info([(x, y) for (x, y) in os.environ.items() if 'SLURM' in x]) | ||
|
||
if args.debug: | ||
download = False | ||
else: | ||
# download = True # for first use | ||
download = False # for speed afterwards | ||
|
||
if os.path.isdir('/home/walml/repos/zoobot'): | ||
search_str = '/home/walml/repos/zoobot/gz_decals_5_train_*.tar' | ||
|
||
else: | ||
search_str = '/home/walml/projects/def-bovy/walml/data/webdatasets/gz_decals_5/gz_decals_5_train_*.tar' | ||
|
||
all_urls = glob.glob(search_str) | ||
assert len(all_urls) > 0, search_str | ||
train_urls, val_urls = all_urls[:6], all_urls[6:] | ||
schema = schemas.decals_dr5_ortho_schema | ||
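    # Note (not in the original commit): glob order is filesystem-dependent, so
    # sorting, e.g. all_urls = sorted(glob.glob(search_str)), would make the
    # all_urls[:6] / all_urls[6:] split above deterministic across runs. With
    # n_shards=8 in the companion shard-writing script, this gives 6 train
    # shards and 2 val shards.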

    # if args.dataset == 'gz_decals_dr5':
    #     schema, (train_catalog, val_catalog, test_catalog) = benchmark_datasets.get_gz_decals_dr5_benchmark_dataset(args.data_dir, random_state, download=download)
    # elif args.dataset == 'gz_evo':
    #     schema, (train_catalog, val_catalog, test_catalog) = benchmark_datasets.get_gz_evo_benchmark_dataset(args.data_dir, random_state, download=download)
    # else:
    #     raise ValueError(f'Dataset {args.dataset} not recognised: should be "gz_decals_dr5" or "gz_evo"')

    # logging.info('First val galaxy: {}'.format(val_catalog.iloc[0]['id_str']))

    # debug mode
    if args.debug:
        logging.warning('Using debug mode: cutting urls down to 2')
        train_urls = train_urls[:2]
        val_urls = val_urls[:2]
        epochs = 2
    else:
        epochs = 1000

    if args.wandb:
        wandb_logger = WandbLogger(
            project='narval',
            name=os.path.basename(args.save_dir),
            log_model=False
        )
    else:
        wandb_logger = None

    train_with_pytorch_lightning.train_default_zoobot_from_scratch(
        save_dir=args.save_dir,
        schema=schema,
        train_urls=train_urls,
        val_urls=val_urls,
        test_urls=None,
        architecture_name=args.architecture_name,
        batch_size=args.batch_size,
        epochs=epochs,  # rely on early stopping
        patience=10,
        # augmentation parameters
        color=args.color,
        resize_after_crop=args.resize_after_crop,
        # hardware parameters
        gpus=args.gpus,
        nodes=args.nodes,
        mixed_precision=args.mixed_precision,
        wandb_logger=wandb_logger,
        prefetch_factor=4,
        num_workers=11,  # system has 24 cpu, 12 cpu per gpu, leave a little wiggle room
        random_state=random_state,
        learning_rate=1e-3,
    )

    wandb.finish()