Skip to content

Commit

Permalink
factor out cascade loop into runner class
Browse files Browse the repository at this point in the history
  • Loading branch information
miketynes committed Dec 3, 2024
1 parent 8e3b72f commit 39b84f0
Show file tree
Hide file tree
Showing 2 changed files with 216 additions and 129 deletions.
214 changes: 85 additions & 129 deletions 3_cascade/1_prototype-cascade.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,11 @@
"from cascade.utils import canonicalize, apply_calculator\n",
"from cascade.auditor import RandomAuditor\n",
"from cascade.learning.torchani import TorchANI\n",
"from cascade.learning.torchani.build import make_output_nets, make_aev_computer"
"from cascade.learning.torchani.build import make_output_nets, make_aev_computer\n",
"from cascade.runner import SerialCascadeRunner\n",
"\n",
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
Expand Down Expand Up @@ -99,14 +103,6 @@
"id": "08a8ef58-4ee3-4686-9ff4-3942211e23c2",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using Materials Project MACE for MACECalculator with /home/mike/.cache/mace/20231210mace128L0_energy_epoch249model\n",
"Using float32 for MACECalculator, which is faster but less accurate. Recommended for MD. Use float64 for geometry optimization.\n"
]
},
{
"name": "stderr",
"output_type": "stream",
Expand All @@ -119,6 +115,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Using Materials Project MACE for MACECalculator with /home/mike/.cache/mace/20231210mace128L0_energy_epoch249model\n",
"Using float32 for MACECalculator, which is faster but less accurate. Recommended for MD. Use float64 for geometry optimization.\n",
"Default dtype float32 does not match model dtype float64, converting models to float32.\n"
]
}
Expand Down Expand Up @@ -175,21 +173,6 @@
"## Minimum viable cascasde loop"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b325e1fd-2737-4650-8808-e1fb2b5b2810",
"metadata": {},
"outputs": [],
"source": [
"class CascadeThinker:\n",
"\n",
" def __init__(self,\n",
" total_steps: int,\n",
" increment_steps: int,\n",
" auditor: "
]
},
{
"cell_type": "code",
"execution_count": null,
Expand All @@ -208,7 +191,11 @@
"name": "stdout",
"output_type": "stream",
"text": [
"On traj 1/2\n",
"**********\n",
"Starting pass 1/10 of cascade loop\n",
"Currently 0 of 2 complete\n",
"Examining trajectory 1 of 2\n",
"Trajectory is trusted, advancing\n",
"Running ML-driven dynamics\n"
]
},
Expand All @@ -224,67 +211,77 @@
"name": "stdout",
"output_type": "stream",
"text": [
"CascadeTrajectory(path=si-diffusion-seed=0.traj, current_timestep=64, last_trusted_timestep=0)\n",
"On traj 2/2\n",
"Examining trajectory 2 of 2\n",
"Trajectory is trusted, advancing\n",
"Running ML-driven dynamics\n",
"CascadeTrajectory(path=si-diffusion-seed=1.traj, current_timestep=64, last_trusted_timestep=0)\n",
"done 0 / 2\n",
"On traj 1/2\n",
"**********\n",
"Starting pass 2/10 of cascade loop\n",
"Currently 0 of 2 complete\n",
"Examining trajectory 1 of 2\n",
"Trajectory has untrusted segment, auditing\n",
"Auditing trajectory\n",
"score < threshold (0.3745401188473625 < 0.5, marking recent segment as trusted\n",
"CascadeTrajectory(path=si-diffusion-seed=0.traj, current_timestep=64, last_trusted_timestep=64)\n",
"On traj 2/2\n",
"Examining trajectory 2 of 2\n",
"Trajectory has untrusted segment, auditing\n",
"Auditing trajectory\n",
"score < threshold (0.034388521115218396 < 0.5, marking recent segment as trusted\n",
"CascadeTrajectory(path=si-diffusion-seed=1.traj, current_timestep=64, last_trusted_timestep=64)\n",
"done 0 / 2\n",
"On traj 1/2\n",
"**********\n",
"Starting pass 3/10 of cascade loop\n",
"Currently 0 of 2 complete\n",
"Examining trajectory 1 of 2\n",
"Trajectory is trusted, advancing\n",
"Running ML-driven dynamics\n",
"CascadeTrajectory(path=si-diffusion-seed=0.traj, current_timestep=128, last_trusted_timestep=64)\n",
"On traj 2/2\n",
"Examining trajectory 2 of 2\n",
"Trajectory is trusted, advancing\n",
"Running ML-driven dynamics\n",
"CascadeTrajectory(path=si-diffusion-seed=1.traj, current_timestep=128, last_trusted_timestep=64)\n",
"done 0 / 2\n",
"On traj 1/2\n",
"**********\n",
"Starting pass 4/10 of cascade loop\n",
"Currently 0 of 2 complete\n",
"Examining trajectory 1 of 2\n",
"Trajectory has untrusted segment, auditing\n",
"Auditing trajectory\n",
"score > threshold (0.6688412526636073 > 0.5), running audit calculations and dropping untrusted segment\n",
"CascadeTrajectory(path=si-diffusion-seed=0.traj, current_timestep=64, last_trusted_timestep=64)\n",
"On traj 2/2\n",
"Examining trajectory 2 of 2\n",
"Trajectory has untrusted segment, auditing\n",
"Auditing trajectory\n",
"score < threshold (0.33761517140362796 < 0.5, marking recent segment as trusted\n",
"CascadeTrajectory(path=si-diffusion-seed=1.traj, current_timestep=128, last_trusted_timestep=128)\n",
"done 0 / 2\n",
"On traj 1/2\n",
"Last audit passed; trajectory complete\n",
"**********\n",
"Starting pass 5/10 of cascade loop\n",
"Currently 1 of 2 complete\n",
"Examining trajectory 1 of 2\n",
"Trajectory is trusted, advancing\n",
"Running ML-driven dynamics\n",
"CascadeTrajectory(path=si-diffusion-seed=0.traj, current_timestep=128, last_trusted_timestep=64)\n",
"On traj 2/2\n",
"Traj is completed, continuing\n",
"done 1 / 2\n",
"On traj 1/2\n",
"Examining trajectory 2 of 2\n",
"Trajectory is done, continuing\n",
"**********\n",
"Starting pass 6/10 of cascade loop\n",
"Currently 1 of 2 complete\n",
"Examining trajectory 1 of 2\n",
"Trajectory has untrusted segment, auditing\n",
"Auditing trajectory\n",
"score > threshold (0.940523264489604 > 0.5), running audit calculations and dropping untrusted segment\n",
"CascadeTrajectory(path=si-diffusion-seed=0.traj, current_timestep=64, last_trusted_timestep=64)\n",
"On traj 2/2\n",
"Traj is completed, continuing\n",
"done 1 / 2\n",
"On traj 1/2\n",
"Examining trajectory 2 of 2\n",
"Trajectory is done, continuing\n",
"**********\n",
"Starting pass 7/10 of cascade loop\n",
"Currently 1 of 2 complete\n",
"Examining trajectory 1 of 2\n",
"Trajectory is trusted, advancing\n",
"Running ML-driven dynamics\n",
"CascadeTrajectory(path=si-diffusion-seed=0.traj, current_timestep=128, last_trusted_timestep=64)\n",
"On traj 2/2\n",
"Traj is completed, continuing\n",
"done 1 / 2\n",
"On traj 1/2\n",
"Examining trajectory 2 of 2\n",
"Trajectory is done, continuing\n",
"**********\n",
"Starting pass 8/10 of cascade loop\n",
"Currently 1 of 2 complete\n",
"Examining trajectory 1 of 2\n",
"Trajectory has untrusted segment, auditing\n",
"Auditing trajectory\n",
"score < threshold (0.09367476782809248 < 0.5, marking recent segment as trusted\n",
"CascadeTrajectory(path=si-diffusion-seed=0.traj, current_timestep=128, last_trusted_timestep=128)\n",
"On traj 2/2\n",
"Traj is completed, continuing\n",
"done 1 / 2\n",
"On traj 1/2\n",
"Traj is completed, continuing\n",
"On traj 2/2\n",
"Traj is completed, continuing\n",
"done 2 / 2\n"
"Last audit passed; trajectory complete\n",
"Examining trajectory 2 of 2\n",
"Trajectory is done, continuing\n",
"Finished all trajectories in 8 iterations\n"
]
}
],
Expand All @@ -296,64 +293,23 @@
" starting=atoms.copy()) for s in seeds]\n",
"# notably, right now, the seeds have no effect since our dynamics are NVE\n",
"\n",
"cascade = SerialCascadeRunner(\n",
" trajectories=trajectories,\n",
" total_steps=128,\n",
" increment_steps=64,\n",
" uq_threshold=0.5,\n",
" auditor=RandomAuditor(random_state=42),\n",
" learner=learner,\n",
" model=model,\n",
" calculator=calc,\n",
" dyn_cls=VelocityVerlet,\n",
" train_kws=dict(device='cpu', num_epochs=1),\n",
" max_train=10,\n",
" val_frac=0.1,\n",
" training_file='train.traj',\n",
")\n",
"\n",
"total_steps = 128 # how long will our final trajectories be\n",
"increment_steps = 64 # how many steps to run with ML at a time\n",
"\n",
"# audits are random\n",
"auditor = RandomAuditor(random_state=42)\n",
"threshold = 0.5 # this is the 'score' threshold on the auditor\n",
"\n",
"done = False\n",
"i = 0 # track while loop iterations\n",
"max_iter = 10 # dont go above this\n",
"while not done:\n",
" \n",
" done_ctr = 0 # count how many trajectories are done\n",
" \n",
" for j, traj in enumerate(trajectories):\n",
" \n",
" ## Check if this trajectory is done\n",
" print(f'On traj {j+1}/{len(trajectories)}')\n",
" if traj.last_trusted_timestep == total_steps: \n",
" done_ctr += 1\n",
" print('Traj is completed, continuing')\n",
" continue\n",
"\n",
" \n",
" ## if we've advanced past a trusted segment, lets audit it\n",
" if traj.current_timestep > traj.last_trusted_timestep: \n",
" print('Auditing trajectory')\n",
" segment = traj.get_untrusted_segment()\n",
" score, audit_frames = auditor.audit(segment, n_audits=32)\n",
" if score > threshold: \n",
" print(f'score > threshold ({score} > {threshold}), running audit calculations and dropping untrusted segment')\n",
" segment = apply_calculator(calc, segment)\n",
" # todo: add to training set\n",
" traj.trim_untrusted_segment()\n",
" else:\n",
" print(f'score < threshold ({score} < {threshold}, marking recent segment as trusted')\n",
" traj.last_trusted_timestep = traj.current_timestep\n",
" traj.trusted = traj.current\n",
" \n",
" # otherwise we can run the ML-driven dynamics \n",
" else:\n",
" # then we run dynamics\n",
" print('Running ML-driven dynamics')\n",
" atoms = traj.trusted.copy()\n",
" atoms.calc = learner.make_calculator(model, device='cpu')\n",
" dyn = VelocityVerlet(atoms=atoms,\n",
" timestep=1*units.fs,\n",
" trajectory=TrajectoryWriter(traj.path, mode='a')\n",
" )\n",
" dyn.run(increment_steps)\n",
" traj.current_timestep += increment_steps\n",
" traj.current = atoms\n",
" print(traj)\n",
" \n",
" i += 1\n",
" print(f'done {done_ctr} / {len(trajectories)}')\n",
" done = done_ctr == len(trajectories) or i == max_iter"
"cascade.run(max_iter=10)"
]
},
{
Expand All @@ -365,9 +321,9 @@
"\n",
"This is great, next steps: \n",
"- [ ] diagram out current/trusted logic\n",
"- [ ] add training\n",
"- [ ] break into functions/classes WIP\n",
"- [ ] add tests WIP\n",
"- [x] add training\n",
"- [x] break into functions/classes WIP\n",
"- [x] add tests WIP\n",
"- [ ] add logging"
]
},
Expand Down
Loading

0 comments on commit 39b84f0

Please sign in to comment.