-
Notifications
You must be signed in to change notification settings - Fork 0
/
sweep_cluster.py
65 lines (52 loc) · 2.23 KB
/
sweep_cluster.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
# %%
import argparse
import itertools
from train import main
from wandb_logger import WandBLogger
import numpy as np
# Project label for this sweep. It is not referenced by the __main__ block of
# this script; presumably read by an external launcher -- verify with callers.
project_name = "Complete_v2"

# Hyper-parameter grid: every key maps to the list of values swept over.
params = {
    "NUM_CONNECTOME_PASSES": [3, 4, 5, 6],
    # NOTE(review): linear spacing across three decades gives roughly
    # [1e-5, 2.5e-3, 5e-3, 7.5e-3, 1e-2]; if an even spread in log-space was
    # intended, np.logspace(-5, -2, num=5) would produce it -- confirm first.
    "base_lr": np.linspace(1e-5, 1e-2, num=5).tolist(),
    "neurons": ["selected", "all"],
    "voronoi_criteria": ["R7", "all"],
    "random_synapses": [True, False],
}

# Sorting by key name yields ["NUM_CONNECTOME_PASSES", "base_lr", "neurons",
# "random_synapses", "voronoi_criteria"] (uppercase sorts before lowercase),
# which is the same order as the positional CLI arguments in __main__.
param_names = sorted(params)

# One tuple per run: the Cartesian product of all value lists, with tuple
# fields ordered as in ``param_names``.
combinations = list(
    itertools.product(*[params[name] for name in param_names])
)
# %%
class SweepConfig:
    """Hyper-parameters for a single sweep run.

    The ``__main__`` block of this script builds one from the parsed command
    line and passes it to ``train.main`` as ``sweep_config``. Each constructor
    argument is stored unchanged as an attribute of the same name.
    """

    def __init__(
        self, neurons, voronoi_criteria, random_synapses, base_lr, NUM_CONNECTOME_PASSES
    ):
        """Store each hyper-parameter as a plain attribute.

        Args:
            neurons: Neuron selection; the sweep grid uses "selected" or "all".
            voronoi_criteria: Voronoi criterion; the sweep grid uses "R7" or "all".
            random_synapses: Whether to use random synapses (a bool when built
                by this script's CLI).
            base_lr: Base learning rate.
            NUM_CONNECTOME_PASSES: Number of connectome passes.
        """
        self.neurons = neurons
        self.voronoi_criteria = voronoi_criteria
        self.random_synapses = random_synapses
        self.base_lr = base_lr
        self.NUM_CONNECTOME_PASSES = NUM_CONNECTOME_PASSES

    def __repr__(self):
        """Return a constructor-style string so logged configs are readable."""
        return (
            f"{type(self).__name__}("
            f"neurons={self.neurons!r}, "
            f"voronoi_criteria={self.voronoi_criteria!r}, "
            f"random_synapses={self.random_synapses!r}, "
            f"base_lr={self.base_lr!r}, "
            f"NUM_CONNECTOME_PASSES={self.NUM_CONNECTOME_PASSES!r})"
        )
def str_to_bool(value):
    """Parse a command-line boolean strictly, for use as an argparse ``type=``.

    Accepts "true"/"false" in any case, plus "1"/"0" and "yes"/"no". Anything
    else raises instead of silently becoming ``False``, so a typo such as
    "Ture" cannot launch the wrong experiment.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    normalized = value.strip().lower()
    if normalized in {"true", "1", "yes"}:
        return True
    if normalized in {"false", "0", "no"}:
        return False
    raise argparse.ArgumentTypeError(
        f"expected a boolean (True or False), got {value!r}"
    )


def build_arg_parser():
    """Return the parser for this script's five positional sweep arguments.

    The positional order (NUM_CONNECTOME_PASSES, base_lr, neurons,
    random_synapses, voronoi_criteria) matches the sorted key order of the
    module-level ``params`` grid.
    """
    parser = argparse.ArgumentParser(description="Process sweep parameters.")
    parser.add_argument(
        "NUM_CONNECTOME_PASSES", type=int, help="Number of connectome passes"
    )
    parser.add_argument("base_lr", type=float, help="Base learning rate")
    parser.add_argument(
        "neurons", type=str, help="Type of neurons (selected or all)"
    )
    parser.add_argument(
        "random_synapses",
        type=str_to_bool,  # strict parse; previously any non-"True" became False
        help="Use random synapses (True or False)",
    )
    parser.add_argument(
        "voronoi_criteria", type=str, help="Voronoi criteria (R7 or all)"
    )
    return parser


if __name__ == "__main__":
    # Parse the arguments
    args = build_arg_parser().parse_args()

    # Create a SweepConfig object from the parsed arguments
    sweep_config = SweepConfig(
        neurons=args.neurons,
        voronoi_criteria=args.voronoi_criteria,
        random_synapses=args.random_synapses,  # already a bool via str_to_bool
        base_lr=args.base_lr,
        NUM_CONNECTOME_PASSES=args.NUM_CONNECTOME_PASSES,
    )

    print("Running with configuration:")
    print(f"Neurons: {sweep_config.neurons}")
    print(f"Voronoi Criteria: {sweep_config.voronoi_criteria}")
    print(f"Random Synapses: {sweep_config.random_synapses}")
    print(f"Base Learning Rate: {sweep_config.base_lr}")
    print(f"Number of Connectome Passes: {sweep_config.NUM_CONNECTOME_PASSES}")

    # NOTE(review): the module-level ``project_name`` ("Complete_v2") is not
    # used here; confirm "adult_complete" is the intended WandBLogger argument.
    wandb_logger = WandBLogger("adult_complete")
    wandb_logger.initialize_run(group="cluster_sweep")
    main(wandb_logger, sweep_config=sweep_config)