diff --git a/latest_configs/multi_16x16_rand_randsnap_lrstep1_a.yaml b/latest_configs/multi_16x16_rand_randsnap_lrstep1_a.yaml new file mode 100644 index 0000000..19fd259 --- /dev/null +++ b/latest_configs/multi_16x16_rand_randsnap_lrstep1_a.yaml @@ -0,0 +1,91 @@ + +global: + seed: 10 + nthetas: 65 + beamformer_input: True #True + empirical_input: True + phase_input: True #True + rx_spacing_input: True #True + n_radios: 2 + +datasets: + train_paths: + - /mnt/4tb_ssd/nosig_data/wallarrayv3_2024*.zarr + val_holdout_fraction: 0.2 + val_subsample_fraction: 0.2 + precompute_cache: /home/mouse9911/precompute_cache_chunk16_sept + skip_qc: True + val_snapshots_per_session: 16 + train_snapshots_per_session: 32 + random_snapshot_size: True + snapshots_stride: 1 + snapshots_adjacent_stride: 32 + val_snapshots_adjacent_stride: 16 + random_adjacent_stride: True + batch_size: 64 + shuffle: True + workers: 20 + sigma: 0.25 + scatter_k: 21 + flip: False + double_flip: False + scatter: continuous + + empirical_data_fn: /home/mouse9911/gits/spf/empirical_dists/full.pkl + empirical_individual_radio: False + empirical_symmetry: True + + +logger: + name: wandb + project: 2024_oct22_single_paired_multi + log_every: 100 + plot_every: 15000 + +model: + name: multipairedbeamformer + load_paired: True + skip_connection: False + use_xy: True + transformer: + d_model: 1024 + n_heads: 8 + d_hid: 256 + dropout: 0.0 + n_layers: 4 + paired: + hidden: 1024 + depth: 4 + block: True + bn: True + norm: layer + detach: True + dropout: 0.0 + load_single: True + single: + hidden: 1024 + depth: 4 + block: True + bn: True + norm: layer + detach: True + input_dropout: 0.3 + dropout: 0.0 + +optim: + head_start: 0 + device: cuda + dtype: torch.float32 + resume_step: 0 + epochs: 10 + direct_loss: False + val_every: 5000 + learning_rate: 5.0e-5 + weight_decay: 0.0000000 + amp: False + loss: mse + scheduler_step: 3 + checkpoint_every: 20000 + output: multipaired_checkpoints + save_on: val_multipaired_loss + checkpoint: /home/mouse9911/gits/spf/oct24_paired_checkpoints/best.pth diff --git a/latest_configs/multi_16x16_rand_traj_randsnap_lrstep1_b.yaml b/latest_configs/multi_16x16_rand_traj_randsnap_lrstep1_b.yaml new file mode 100644 index 0000000..907e318 --- /dev/null +++ b/latest_configs/multi_16x16_rand_traj_randsnap_lrstep1_b.yaml @@ -0,0 +1,95 @@ + +global: + seed: 10 + nthetas: 65 + beamformer_input: True #True + empirical_input: True + phase_input: True #True + rx_spacing_input: True #True + n_radios: 2 + +datasets: + train_paths: + - /mnt/4tb_ssd/nosig_data/wallarrayv3_2024*.zarr + val_holdout_fraction: 0.2 + val_subsample_fraction: 0.2 + precompute_cache: /home/mouse9911/precompute_cache_chunk16_sept + skip_qc: True + val_snapshots_per_session: 16 + train_snapshots_per_session: 32 + random_snapshot_size: True + snapshots_stride: 1 + snapshots_adjacent_stride: 32 + val_snapshots_adjacent_stride: 16 + random_adjacent_stride: True + batch_size: 64 + shuffle: True + workers: 20 + sigma: 0.25 + scatter_k: 21 + flip: False + double_flip: False + scatter: continuous + + empirical_data_fn: /home/mouse9911/gits/spf/empirical_dists/full.pkl + empirical_individual_radio: False + empirical_symmetry: True + + +logger: + name: wandb + project: 2024_oct22_single_paired_multi + log_every: 100 + plot_every: 15000 + +model: + name: trajmultipairedbeamformer + load_paired: True + skip_connection: False + latent: 64 + traj_hidden: 512 + traj_layers: 8 + use_xy: True + pred_xy: True + transformer: + d_model: 2048 + n_heads: 8 + d_hid: 256 + dropout: 0.0 + n_layers: 8 + paired: + hidden: 1024 + depth: 4 + block: True + bn: True + norm: layer + detach: True + dropout: 0.0 + load_single: True + single: + hidden: 1024 + depth: 4 + block: True + bn: True + norm: layer + detach: True + input_dropout: 0.3 + dropout: 0.0 + +optim: + head_start: 0 + device: cuda + dtype: torch.float32 + resume_step: 0 + epochs: 10 + direct_loss: False + val_every: 5000 + learning_rate: 1.0e-5 + weight_decay: 0.0000000 + amp: False + loss: mse + scheduler_step: 3 + checkpoint_every: 20000 + output: multipaired_traj_checkpoints + save_on: val_multipaired_loss + checkpoint: /home/mouse9911/gits/spf/oct24_paired_checkpoints/best.pth diff --git a/latest_configs/paired.yaml b/latest_configs/paired.yaml new file mode 100644 index 0000000..a609382 --- /dev/null +++ b/latest_configs/paired.yaml @@ -0,0 +1,72 @@ +datasets: + batch_size: 256 + empirical_data_fn: /home/mouse9911/gits/spf/empirical_dists/full.pkl + empirical_individual_radio: false + empirical_symmetry: true + flip: false + double_flip: True + precompute_cache: /home/mouse9911/precompute_cache_chunk16_sept + scatter: continuous + scatter_k: 21 + shuffle: true + sigma: 0.25 + skip_qc: true + snapshots_adjacent_stride: 1 + train_snapshots_per_session: 1 + val_snapshots_per_session: 1 + random_snapshot_size: False + snapshots_stride: 1 + train_paths: + - /mnt/4tb_ssd/nosig_data/wallarrayv3_2024*.zarr + val_holdout_fraction: 0.2 + val_subsample_fraction: 0.2 + workers: 20 +global: + beamformer_input: true + empirical_input: true + n_radios: 2 + nthetas: 65 + phase_input: true + rx_spacing_input: true + seed: 10 +logger: + log_every: 100 + name: wandb + plot_every: 15000 + project: 2024_nov2_single_paired_multi +model: + block: true + bn: true + depth: 4 + detach: true + dropout: 0.0 + hidden: 1024 + load_single: true + name: pairedbeamformer + norm: layer + single: + block: true + bn: true + depth: 4 + detach: true + dropout: 0.0 + hidden: 1024 + input_dropout: 0.3 + norm: layer +optim: + amp: true + checkpoint: /home/mouse9911/gits/spf/nov2_checkpoints/single_checkpoints_inputdo0p3/best.pth + checkpoint_every: 5000 + device: cuda + direct_loss: false + dtype: torch.float32 + epochs: 60 + head_start: 0 + learning_rate: 0.0002 + loss: mse + output: /home/mouse9911/gits/spf/nov2_checkpoints/paired_checkpoints_inputdo0p3 + resume_step: 0 + save_on: val_paired_loss + scheduler_step: 6 + val_every: 10000 + weight_decay: 0.0 diff --git a/latest_configs/paired_noflip.yaml b/latest_configs/paired_noflip.yaml new file mode 100644 index 0000000..1b3e436 --- /dev/null +++ b/latest_configs/paired_noflip.yaml @@ -0,0 +1,72 @@ +datasets: + batch_size: 256 + empirical_data_fn: /home/mouse9911/gits/spf/empirical_dists/full.pkl + empirical_individual_radio: false + empirical_symmetry: true + flip: false + double_flip: false + precompute_cache: /home/mouse9911/precompute_cache_chunk16_sept + scatter: continuous + scatter_k: 21 + shuffle: true + sigma: 0.25 + skip_qc: true + snapshots_adjacent_stride: 1 + train_snapshots_per_session: 1 + val_snapshots_per_session: 1 + random_snapshot_size: False + snapshots_stride: 1 + train_paths: + - /mnt/4tb_ssd/nosig_data/wallarrayv3_2024*.zarr + val_holdout_fraction: 0.2 + val_subsample_fraction: 0.2 + workers: 20 +global: + beamformer_input: true + empirical_input: true + n_radios: 2 + nthetas: 65 + phase_input: true + rx_spacing_input: true + seed: 10 +logger: + log_every: 100 + name: wandb + plot_every: 15000 + project: 2024_oct22_single_paired_multi +model: + block: true + bn: true + depth: 4 + detach: true + dropout: 0.0 + hidden: 1024 + load_single: true + name: pairedbeamformer + norm: layer + single: + block: true + bn: true + depth: 4 + detach: true + dropout: 0.0 + hidden: 1024 + input_dropout: 0.3 + norm: layer +optim: + amp: true + checkpoint: /home/mouse9911/gits/spf/oct23_single_checkpoints_inputdo0p3_x3/best.pth + checkpoint_every: 5000 + device: cuda + direct_loss: false + dtype: torch.float32 + epochs: 40 + head_start: 0 + learning_rate: 0.0001 + loss: mse + output: oct23_paired_checkpoints_noflip + resume_step: 0 + save_on: val_paired_loss + scheduler_step: 20 + val_every: 10000 + weight_decay: 0.0 diff --git a/latest_configs/single_config.yaml b/latest_configs/single_config.yaml new file mode 100644 index 0000000..e588930 --- /dev/null +++ b/latest_configs/single_config.yaml @@ -0,0 +1,62 @@ +datasets: + batch_size: 256 + empirical_data_fn: /home/mouse9911/gits/spf/empirical_dists/full.pkl + empirical_individual_radio: false + empirical_symmetry: true + flip: true + double_flip: false + precompute_cache: /home/mouse9911/precompute_cache_chunk16_sept + scatter: continuous + scatter_k: 21 + shuffle: true + sigma: 0.25 + skip_qc: true + snapshots_adjacent_stride: 1 + train_snapshots_per_session: 1 + val_snapshots_per_session: 1 + random_snapshot_size: False + snapshots_stride: 1 + train_paths: + - /mnt/4tb_ssd/nosig_data/wallarrayv3_2024*.zarr + val_holdout_fraction: 0.2 + val_subsample_fraction: 0.2 + workers: 20 +global: + beamformer_input: true + empirical_input: true + n_radios: 2 + nthetas: 65 + phase_input: true + rx_spacing_input: true + seed: 10 +logger: + log_every: 100 + name: wandb + plot_every: 15000 + project: 2024_nov2_single_paired_multi +model: + block: true + bn: true + depth: 4 + detach: true + dropout: 0.0 + hidden: 1024 + input_dropout: 0.3 + name: beamformer + norm: layer +optim: + amp: true + checkpoint_every: 5000 + device: cuda + direct_loss: false + dtype: torch.float32 + epochs: 30 + head_start: 0 + learning_rate: 0.002 + loss: mse + output: /home/mouse9911/gits/spf/nov2_checkpoints/single_checkpoints_inputdo0p3 + resume_step: 0 + save_on: val_single_loss + scheduler_step: 7 + val_every: 10000 + weight_decay: 0.0 diff --git a/latest_configs/single_config_noflip.yaml b/latest_configs/single_config_noflip.yaml new file mode 100644 index 0000000..f0f80ed --- /dev/null +++ b/latest_configs/single_config_noflip.yaml @@ -0,0 +1,62 @@ +datasets: + batch_size: 256 + empirical_data_fn: /home/mouse9911/gits/spf/empirical_dists/full.pkl + empirical_individual_radio: false + empirical_symmetry: true + flip: false + double_flip: false + precompute_cache: /home/mouse9911/precompute_cache_chunk16_sept + scatter: continuous + scatter_k: 21 + shuffle: true + sigma: 0.25 + skip_qc: true + snapshots_adjacent_stride: 1 + train_snapshots_per_session: 1 + val_snapshots_per_session: 1 + random_snapshot_size: False + snapshots_stride: 1 + train_paths: + - /mnt/4tb_ssd/nosig_data/wallarrayv3_2024*.zarr + val_holdout_fraction: 0.2 + val_subsample_fraction: 0.2 + workers: 20 +global: + beamformer_input: true + empirical_input: true + n_radios: 2 + nthetas: 65 + phase_input: true + rx_spacing_input: true + seed: 10 +logger: + log_every: 100 + name: wandb + plot_every: 15000 + project: 2024_oct22_single_paired_multi +model: + block: true + bn: true + depth: 4 + detach: true + dropout: 0.0 + hidden: 1024 + input_dropout: 0.3 + name: beamformer + norm: layer +optim: + amp: true + checkpoint_every: 5000 + device: cuda + direct_loss: false + dtype: torch.float32 + epochs: 25 + head_start: 0 + learning_rate: 0.002 + loss: mse + output: oct23_single_checkpoints_inputdo0p3_noflip_x3 + resume_step: 0 + save_on: val_single_loss + scheduler_step: 7 + val_every: 10000 + weight_decay: 0.0 diff --git a/spf/model_training_and_inference/models/single_point_networks_inference.py b/spf/model_training_and_inference/models/single_point_networks_inference.py new file mode 100644 index 0000000..bc7f059 --- /dev/null +++ b/spf/model_training_and_inference/models/single_point_networks_inference.py @@ -0,0 +1,43 @@ +from spf.scripts.train_single_point import ( + load_checkpoint, + load_config_from_fn, + load_model, +) + + +def load_model_and_config_from_config_fn_and_checkpoint(config_fn, checkpoint_fn): + config = load_config_from_fn(config_fn) + config["optim"]["checkpoint"] = checkpoint_fn + m = load_model(config["model"], config["global"]).to(config["optim"]["device"]) + m, _, _, _, _ = load_checkpoint( + checkpoint_fn=config["optim"]["checkpoint"], + config=config, + model=m, + optimizer=None, + scheduler=None, + force_load=True, + ) + return m, config + + +def convert_datasets_config_to_inference(datasets_config, ds_fn): + datasets_config = datasets_config.copy() + datasets_config.update( + { + "batch_size": 1, + "flip": False, + "double_flip": False, + "precompute_cache": "/home/mouse9911/precompute_cache_chunk16_sept", + "shuffle": False, + "skip_qc": True, + "snapshots_adjacent_stride": 1, + "train_snapshots_per_session": 1, + "val_snapshots_per_session": 1, + "random_snapshot_size": False, + "snapshots_stride": 1, + "train_paths": [ds_fn], + "train_on_val": True, + "workers": 1, + } + ) + return datasets_config