From 39b84f046b78534f54df932adb608d900dd7a6e7 Mon Sep 17 00:00:00 2001 From: miketynes Date: Tue, 3 Dec 2024 16:00:39 -0600 Subject: [PATCH] factor out cascade loop into runner class --- 3_cascade/1_prototype-cascade.ipynb | 214 +++++++++++----------------- cascade/runner.py | 131 +++++++++++++++++ 2 files changed, 216 insertions(+), 129 deletions(-) create mode 100644 cascade/runner.py diff --git a/3_cascade/1_prototype-cascade.ipynb b/3_cascade/1_prototype-cascade.ipynb index caab114..e9bc16e 100644 --- a/3_cascade/1_prototype-cascade.ipynb +++ b/3_cascade/1_prototype-cascade.ipynb @@ -59,7 +59,11 @@ "from cascade.utils import canonicalize, apply_calculator\n", "from cascade.auditor import RandomAuditor\n", "from cascade.learning.torchani import TorchANI\n", - "from cascade.learning.torchani.build import make_output_nets, make_aev_computer" + "from cascade.learning.torchani.build import make_output_nets, make_aev_computer\n", + "from cascade.runner import SerialCascadeRunner\n", + "\n", + "%load_ext autoreload\n", + "%autoreload 2" ] }, { @@ -99,14 +103,6 @@ "id": "08a8ef58-4ee3-4686-9ff4-3942211e23c2", "metadata": {}, "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Using Materials Project MACE for MACECalculator with /home/mike/.cache/mace/20231210mace128L0_energy_epoch249model\n", - "Using float32 for MACECalculator, which is faster but less accurate. Recommended for MD. Use float64 for geometry optimization.\n" - ] - }, { "name": "stderr", "output_type": "stream", @@ -119,6 +115,8 @@ "name": "stdout", "output_type": "stream", "text": [ + "Using Materials Project MACE for MACECalculator with /home/mike/.cache/mace/20231210mace128L0_energy_epoch249model\n", + "Using float32 for MACECalculator, which is faster but less accurate. Recommended for MD. Use float64 for geometry optimization.\n", "Default dtype float32 does not match model dtype float64, converting models to float32.\n" ] } @@ -175,21 +173,6 @@ "## Minimum viable cascasde loop" ] }, - { - "cell_type": "code", - "execution_count": null, - "id": "b325e1fd-2737-4650-8808-e1fb2b5b2810", - "metadata": {}, - "outputs": [], - "source": [ - "class CascadeThinker:\n", - "\n", - " def __init__(self,\n", - " total_steps: int,\n", - " increment_steps: int,\n", - " auditor: " - ] - }, { "cell_type": "code", "execution_count": null, @@ -208,7 +191,11 @@ "name": "stdout", "output_type": "stream", "text": [ - "On traj 1/2\n", + "**********\n", + "Starting pass 1/10 of cascade loop\n", + "Currently 0 of 2 complete\n", + "Examining trajectory 1 of 2\n", + "Trajectory is trusted, advancing\n", "Running ML-driven dynamics\n" ] }, @@ -224,67 +211,77 @@ "name": "stdout", "output_type": "stream", "text": [ - "CascadeTrajectory(path=si-diffusion-seed=0.traj, current_timestep=64, last_trusted_timestep=0)\n", - "On traj 2/2\n", + "Examining trajectory 2 of 2\n", + "Trajectory is trusted, advancing\n", "Running ML-driven dynamics\n", - "CascadeTrajectory(path=si-diffusion-seed=1.traj, current_timestep=64, last_trusted_timestep=0)\n", - "done 0 / 2\n", - "On traj 1/2\n", + "**********\n", + "Starting pass 2/10 of cascade loop\n", + "Currently 0 of 2 complete\n", + "Examining trajectory 1 of 2\n", + "Trajectory has untrusted segment, auditing\n", "Auditing trajectory\n", "score < threshold (0.3745401188473625 < 0.5, marking recent segment as trusted\n", - "CascadeTrajectory(path=si-diffusion-seed=0.traj, current_timestep=64, last_trusted_timestep=64)\n", - "On traj 2/2\n", + "Examining trajectory 2 of 2\n", + "Trajectory has untrusted segment, auditing\n", "Auditing trajectory\n", "score < threshold (0.034388521115218396 < 0.5, marking recent segment as trusted\n", - "CascadeTrajectory(path=si-diffusion-seed=1.traj, current_timestep=64, last_trusted_timestep=64)\n", - "done 0 / 2\n", - "On traj 1/2\n", + "**********\n", + "Starting pass 3/10 of cascade loop\n", + "Currently 0 of 2 complete\n", + "Examining trajectory 1 of 2\n", + "Trajectory is trusted, advancing\n", "Running ML-driven dynamics\n", - "CascadeTrajectory(path=si-diffusion-seed=0.traj, current_timestep=128, last_trusted_timestep=64)\n", - "On traj 2/2\n", + "Examining trajectory 2 of 2\n", + "Trajectory is trusted, advancing\n", "Running ML-driven dynamics\n", - "CascadeTrajectory(path=si-diffusion-seed=1.traj, current_timestep=128, last_trusted_timestep=64)\n", - "done 0 / 2\n", - "On traj 1/2\n", + "**********\n", + "Starting pass 4/10 of cascade loop\n", + "Currently 0 of 2 complete\n", + "Examining trajectory 1 of 2\n", + "Trajectory has untrusted segment, auditing\n", "Auditing trajectory\n", "score > threshold (0.6688412526636073 > 0.5), running audit calculations and dropping untrusted segment\n", - "CascadeTrajectory(path=si-diffusion-seed=0.traj, current_timestep=64, last_trusted_timestep=64)\n", - "On traj 2/2\n", + "Examining trajectory 2 of 2\n", + "Trajectory has untrusted segment, auditing\n", "Auditing trajectory\n", "score < threshold (0.33761517140362796 < 0.5, marking recent segment as trusted\n", - "CascadeTrajectory(path=si-diffusion-seed=1.traj, current_timestep=128, last_trusted_timestep=128)\n", - "done 0 / 2\n", - "On traj 1/2\n", + "Last audit passed; trajectory complete\n", + "**********\n", + "Starting pass 5/10 of cascade loop\n", + "Currently 1 of 2 complete\n", + "Examining trajectory 1 of 2\n", + "Trajectory is trusted, advancing\n", "Running ML-driven dynamics\n", - "CascadeTrajectory(path=si-diffusion-seed=0.traj, current_timestep=128, last_trusted_timestep=64)\n", - "On traj 2/2\n", - "Traj is completed, continuing\n", - "done 1 / 2\n", - "On traj 1/2\n", + "Examining trajectory 2 of 2\n", + "Trajectory is done, continuing\n", + "**********\n", + "Starting pass 6/10 of cascade loop\n", + "Currently 1 of 2 complete\n", + "Examining trajectory 1 of 2\n", + "Trajectory has untrusted segment, auditing\n", "Auditing trajectory\n", "score > threshold (0.940523264489604 > 0.5), running audit calculations and dropping untrusted segment\n", - "CascadeTrajectory(path=si-diffusion-seed=0.traj, current_timestep=64, last_trusted_timestep=64)\n", - "On traj 2/2\n", - "Traj is completed, continuing\n", - "done 1 / 2\n", - "On traj 1/2\n", + "Examining trajectory 2 of 2\n", + "Trajectory is done, continuing\n", + "**********\n", + "Starting pass 7/10 of cascade loop\n", + "Currently 1 of 2 complete\n", + "Examining trajectory 1 of 2\n", + "Trajectory is trusted, advancing\n", "Running ML-driven dynamics\n", - "CascadeTrajectory(path=si-diffusion-seed=0.traj, current_timestep=128, last_trusted_timestep=64)\n", - "On traj 2/2\n", - "Traj is completed, continuing\n", - "done 1 / 2\n", - "On traj 1/2\n", + "Examining trajectory 2 of 2\n", + "Trajectory is done, continuing\n", + "**********\n", + "Starting pass 8/10 of cascade loop\n", + "Currently 1 of 2 complete\n", + "Examining trajectory 1 of 2\n", + "Trajectory has untrusted segment, auditing\n", "Auditing trajectory\n", "score < threshold (0.09367476782809248 < 0.5, marking recent segment as trusted\n", - "CascadeTrajectory(path=si-diffusion-seed=0.traj, current_timestep=128, last_trusted_timestep=128)\n", - "On traj 2/2\n", - "Traj is completed, continuing\n", - "done 1 / 2\n", - "On traj 1/2\n", - "Traj is completed, continuing\n", - "On traj 2/2\n", - "Traj is completed, continuing\n", - "done 2 / 2\n" + "Last audit passed; trajectory complete\n", + "Examining trajectory 2 of 2\n", + "Trajectory is done, continuing\n", + "Finished all trajectories in 8 iterations\n" ] } ], @@ -296,64 +293,23 @@ " starting=atoms.copy()) for s in seeds]\n", "# notably, right now, the seeds have no effect since our dynamics are NVE\n", "\n", + "cascade = SerialCascadeRunner(\n", + " trajectories=trajectories,\n", + " total_steps=128,\n", + " increment_steps=64,\n", + " uq_threshold=0.5,\n", + " auditor=RandomAuditor(random_state=42),\n", + " learner=learner,\n", + " model=model,\n", + " calculator=calc,\n", + " dyn_cls=VelocityVerlet,\n", + " train_kws=dict(device='cpu', num_epochs=1),\n", + " max_train=10,\n", + " val_frac=0.1,\n", + " training_file='train.traj',\n", + ")\n", "\n", - "total_steps = 128 # how long will our final trajectories be\n", - "increment_steps = 64 # how many steps to run with ML at a time\n", - "\n", - "# audits are random\n", - "auditor = RandomAuditor(random_state=42)\n", - "threshold = 0.5 # this is the 'score' threshold on the auditor\n", - "\n", - "done = False\n", - "i = 0 # track while loop iterations\n", - "max_iter = 10 # dont go above this\n", - "while not done:\n", - " \n", - " done_ctr = 0 # count how many trajectories are done\n", - " \n", - " for j, traj in enumerate(trajectories):\n", - " \n", - " ## Check if this trajectory is done\n", - " print(f'On traj {j+1}/{len(trajectories)}')\n", - " if traj.last_trusted_timestep == total_steps: \n", - " done_ctr += 1\n", - " print('Traj is completed, continuing')\n", - " continue\n", - "\n", - " \n", - " ## if we've advanced past a trusted segment, lets audit it\n", - " if traj.current_timestep > traj.last_trusted_timestep: \n", - " print('Auditing trajectory')\n", - " segment = traj.get_untrusted_segment()\n", - " score, audit_frames = auditor.audit(segment, n_audits=32)\n", - " if score > threshold: \n", - " print(f'score > threshold ({score} > {threshold}), running audit calculations and dropping untrusted segment')\n", - " segment = apply_calculator(calc, segment)\n", - " # todo: add to training set\n", - " traj.trim_untrusted_segment()\n", - " else:\n", - " print(f'score < threshold ({score} < {threshold}, marking recent segment as trusted')\n", - " traj.last_trusted_timestep = traj.current_timestep\n", - " traj.trusted = traj.current\n", - " \n", - " # otherwise we can run the ML-driven dynamics \n", - " else:\n", - " # then we run dynamics\n", - " print('Running ML-driven dynamics')\n", - " atoms = traj.trusted.copy()\n", - " atoms.calc = learner.make_calculator(model, device='cpu')\n", - " dyn = VelocityVerlet(atoms=atoms,\n", - " timestep=1*units.fs,\n", - " trajectory=TrajectoryWriter(traj.path, mode='a')\n", - " )\n", - " dyn.run(increment_steps)\n", - " traj.current_timestep += increment_steps\n", - " traj.current = atoms\n", - " print(traj)\n", - " \n", - " i += 1\n", - " print(f'done {done_ctr} / {len(trajectories)}')\n", - " done = done_ctr == len(trajectories) or i == max_iter" + "cascade.run(max_iter=10)" ] }, { @@ -365,9 +321,9 @@ "\n", "This is great, next steps: \n", "- [ ] diagram out current/trusted logic\n", - "- [ ] add training\n", - "- [ ] break into functions/classes WIP\n", - "- [ ] add tests WIP\n", + "- [x] add training\n", + "- [x] break into functions/classes WIP\n", + "- [x] add tests WIP\n", "- [ ] add logging" ] }, diff --git a/cascade/runner.py b/cascade/runner.py new file mode 100644 index 0000000..105dcf2 --- /dev/null +++ b/cascade/runner.py @@ -0,0 +1,131 @@ +"""Cascade Runners -- they run cascade""" + +from ase.calculators.calculator import Calculator +from ase.io.trajectory import TrajectoryWriter +from ase.io import read +from ase.md.md import MolecularDynamics +from ase import units +from sklearn.model_selection import train_test_split + +from cascade.learning.base import BaseLearnableForcefield +from cascade.auditor import BaseAuditor +from cascade.trajectory import CascadeTrajectory +from cascade.utils import apply_calculator + + +class SerialCascadeRunner: + + def __init__(self, + trajectories: list[CascadeTrajectory], + total_steps: int, + increment_steps: int, + uq_threshold: float, + auditor: BaseAuditor, + calculator: Calculator, + learner: BaseLearnableForcefield, + model: bytes, + training_file: str, + dyn_cls: type[MolecularDynamics], # I wonder if we could be even more generic + train_kws: dict = None, + val_frac: float = 0.1, + max_train: int = None + ): + self.trajectories = trajectories + self.total_steps = total_steps + self.increment_steps = increment_steps + self.uq_threshold = uq_threshold + self.auditor = auditor + self.calculator = calculator + self.learner = learner + self.model = model + self.training_file = training_file + self.dyn_cls = dyn_cls + self.train_kws = train_kws if train_kws is not None else {} + self.val_frac = val_frac + self.max_train = -max_train if max_train is not None else '' # store as negative for indexing + + @property + def n_trajectories(self): + return len(self.trajectories) + + def run(self, max_iter=None): + i = 0 # track while loop iterations + done_indices = [] + while True: + print('*'*10) + print(f'Starting pass {i+1}/{max_iter} of cascade loop') + print(f'Currently {len(done_indices)} of {self.n_trajectories} complete') + for j, traj in enumerate(self.trajectories): + print(f'Examining trajectory {j+1} of {self.n_trajectories}') + if j in done_indices: + print('Trajectory is done, continuing') + continue + # if we've advanced past a trusted segment, lets audit it + if traj.current_timestep > traj.last_trusted_timestep: + print('Trajectory has untrusted segment, auditing') + self._audit_untrusted_segment(traj) + + if traj.last_trusted_timestep == self.total_steps: + print('Last audit passed; trajectory complete') + done_indices.append(j) + # otherwise we can run the ML-driven dynamics + else: + print('Trajectory is trusted, advancing') + self._advance_trajectory(traj) + i += 1 + # self._update_model() + if len(done_indices) == self.n_trajectories: + print(f'Finished all trajectories in {i} iterations') + break + elif i == max_iter: + print('Hit max iterations, stopping') + break + + def _update_model(self): + # this will have to change quite a bit for the parallel version + print('Updating model') + train = read(self.training_file, index=f'{self.max_train}:') + print(f'read {len(train)} frames for training') + train, val = train_test_split(train, test_size=self.val_frac) + self.model, perf = self.learner.train(self.model, train, val, **self.train_kws) + + def _advance_trajectory(self, traj): + """Advance the trajectory under the current ML surrogate""" + print('Running ML-driven dynamics') + atoms = traj.trusted.copy() + atoms.calc = self.learner.make_calculator(self.model, device='cpu') + dyn = self.dyn_cls(atoms=atoms, + timestep=1*units.fs, + trajectory=TrajectoryWriter(traj.path, mode='a') + ) + dyn.run(self.increment_steps) + traj.current_timestep += self.increment_steps + traj.current = atoms + + def _audit_untrusted_segment(self, traj): + """Audits the untrusted segment of a trajectory + + If the score is above the threshold, the apply the reference calculator + and update the training set. Else mark the segment as trusted + """ + print('Auditing trajectory') + segment = traj.get_untrusted_segment() + score, audit_frames = self.auditor.audit(segment, n_audits=32) + if score > self.uq_threshold: + print(f'score > threshold ({score} > {self.uq_threshold}), running audit calculations and dropping untrusted segment') + + # apply the expensive calculations + segment = apply_calculator(self.calculator, segment) + + # save calculations to disk + writer = TrajectoryWriter(self.training_file, mode='a') + for atoms in segment: + writer.write(atoms) + + # remove the untrusted calculations + traj.trim_untrusted_segment() + + else: + print(f'score < threshold ({score} < {self.uq_threshold}, marking recent segment as trusted') + traj.last_trusted_timestep = traj.current_timestep + traj.trusted = traj.current