From ab7a326e0c5cdb0a7177e7dd84f86d97a89767c0 Mon Sep 17 00:00:00 2001 From: miketynes Date: Fri, 15 Nov 2024 12:58:50 -0600 Subject: [PATCH] extremely minimal viable prototype --- 3_cascade/1_prototype-cascade.ipynb | 382 ++++++++++++++-------------- 1 file changed, 194 insertions(+), 188 deletions(-) diff --git a/3_cascade/1_prototype-cascade.ipynb b/3_cascade/1_prototype-cascade.ipynb index 63d54a2..98df416 100644 --- a/3_cascade/1_prototype-cascade.ipynb +++ b/3_cascade/1_prototype-cascade.ipynb @@ -16,23 +16,39 @@ }, { "cell_type": "code", - "execution_count": 40, + "execution_count": 1, "id": "9c49f1f6-3a41-4d70-8126-236faccc3f4d", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/mike/miniconda3/envs/cascade/lib/python3.11/site-packages/e3nn/o3/_wigner.py:10: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n", + " _Jd, _W3j_flat, _W3j_indices = torch.load(os.path.join(os.path.dirname(__file__), 'constants.pt'))\n", + "/home/mike/miniconda3/envs/cascade/lib/python3.11/site-packages/torchani/aev.py:16: UserWarning: cuaev not installed\n", + " warnings.warn(\"cuaev not installed\")\n", + "/home/mike/miniconda3/envs/cascade/lib/python3.11/site-packages/ignite/handlers/checkpoint.py:16: DeprecationWarning: `TorchScript` support for functional optimizers is deprecated and will be removed in a future PyTorch release. Consider using the `torch.compile` optimizer instead.\n", + " from torch.distributed.optim import ZeroRedundancyOptimizer\n" + ] + } + ], "source": [ "from glob import glob\n", "from pathlib import Path\n", + "from dataclasses import dataclass, field\n", "\n", + "\n", + "import ase\n", "from ase.io import read, write\n", - "from ase.io.trajectory import Trajectory\n", + "from ase.io.trajectory import Trajectory, TrajectoryWriter\n", "from ase import units\n", "from ase.md import MDLogger, VelocityVerlet\n", "import numpy as np\n", "from mace.calculators import mace_mp\n", "\n", "\n", - "from cascade.utils import canonicalize\n", + "from cascade.utils import canonicalize, apply_calculator\n", "from cascade.auditor import RandomAuditor\n", "from cascade.learning.torchani import TorchANI\n", "from cascade.learning.torchani.build import make_output_nets, make_aev_computer" @@ -76,20 +92,22 @@ "metadata": {}, "outputs": [ { - "name": "stderr", + "name": "stdout", "output_type": "stream", "text": [ - "/home/mike/miniconda3/envs/cascade/lib/python3.11/site-packages/mace/calculators/mace.py:128: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. 
It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n", - " torch.load(f=model_path, map_location=device)\n" + "Using Materials Project MACE for MACECalculator with /home/mike/.cache/mace/20231210mace128L0_energy_epoch249model\n", + "Using float32 for MACECalculator, which is faster but less accurate. Recommended for MD. Use float64 for geometry optimization.\n", + "Default dtype float32 does not match model dtype float64, converting models to float32.\n" ] }, { - "name": "stdout", + "name": "stderr", "output_type": "stream", "text": [ - "Using Materials Project MACE for MACECalculator with /home/mike/.cache/mace/20231210mace128L0_energy_epoch249model\n", - "Using float32 for MACECalculator, which is faster but less accurate. Recommended for MD. Use float64 for geometry optimization.\n", - "Default dtype float32 does not match model dtype float64, converting models to float32.\n" + "/home/mike/miniconda3/envs/cascade/lib/python3.11/site-packages/torch/cuda/__init__.py:128: UserWarning: CUDA initialization: CUDA unknown error - this may be due to an incorrectly set up environment, e.g. changing env variable CUDA_VISIBLE_DEVICES after program start. Setting the available devices to be zero. (Triggered internally at ../c10/cuda/CUDAFunctions.cpp:108.)\n", + " return torch._C._cuda_getDeviceCount() > 0\n", + "/home/mike/miniconda3/envs/cascade/lib/python3.11/site-packages/mace/calculators/mace.py:128: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. 
Please open an issue on GitHub for any issues related to this experimental feature.\n", + " torch.load(f=model_path, map_location=device)\n" ] } ], @@ -110,7 +128,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "id": "a8f2e17c-c414-42da-ba43-3ebe7a49176d", "metadata": {}, "outputs": [], @@ -120,7 +138,7 @@ }, { "cell_type": "code", - "execution_count": 38, + "execution_count": 5, "id": "29fab573-4c07-49d0-aeed-db13b6c8f7df", "metadata": {}, "outputs": [], @@ -141,11 +159,12 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "id": "e62cbd7c-4a4e-46a9-8d70-b986a95ac792", "metadata": {}, "outputs": [], "source": [ + "@dataclass\n", "class CascadeTrajectory:\n", " \"\"\"A class to encasplulate a cascade trajectory\n", "\n", @@ -153,11 +172,21 @@ " so we know where to start sampling from (e.g., after the last trusted timestep)\n", " \"\"\"\n", "\n", - " def __init__(self,\n", - " path: str,\n", - " last_trusted_timestep: int = 0):\n", + "\n", + " def __init__(self, \n", + " path: str, \n", + " starting: ase.Atoms = None):\n", " self.path = path\n", - " self.last_trusted_timestep = last_trusted_timestep\n", + " self.starting = starting\n", + " \n", + " if self.starting is not None:\n", + " write(self.path, self.starting)\n", + " else:\n", + " self.starting = read(self.path)\n", + " \n", + " self.current = starting\n", + " self.current_timestep = 0\n", + " self.last_trusted_timestep = 0\n", " \n", " def read(self, index=':', *args, **kwargs):\n", " return read(self.path, *args, index=index, **kwargs)\n", @@ -168,6 +197,10 @@ " def trim_untrusted_segment(self):\n", " # todo: is there a way to do this without loading into memory?\n", " write(self.path, read(self.path, index=f':{self.last_trusted_timestep+1}'))\n", + " self.current_timestep = self.last_trusted_timestep\n", + "\n", + " def __repr__(self): \n", + " return f\"CascadeTrajectory(path={self.path}, current_timestep={self.current_timestep}, last_trusted_timestep={self.last_trusted_timestep})\"\n", " " ] }, @@ -183,7 +216,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 7, "id": "1b225651-b0ee-46be-ac8d-8b55d8ea4c83", "metadata": {}, "outputs": [], @@ -193,7 +226,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 8, "id": "766de0f0-2df7-4dc2-813a-06e414047c97", "metadata": {}, "outputs": [], @@ -203,7 +236,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 9, "id": "10c2c251-dcc8-4e1d-b8fc-f4e51fd9f552", "metadata": {}, "outputs": [ @@ -214,7 +247,7 @@ " Atoms(symbols='Si63', pbc=True, cell=[10.86, 10.86, 10.86])]" ] }, - "execution_count": 10, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -225,7 +258,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 10, "id": "f0e8a83c-8c95-44f3-ab42-02806c1ed390", "metadata": {}, "outputs": [ @@ -235,7 +268,7 @@ "[Atoms(symbols='Si63', pbc=True, cell=[10.86, 10.86, 10.86])]" ] }, - "execution_count": 11, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -246,7 +279,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 11, "id": "60a186dd-037b-4395-8284-5933d051c3b4", "metadata": {}, "outputs": [ @@ -256,7 +289,7 @@ "[Atoms(symbols='Si63', pbc=True, cell=[10.86, 10.86, 10.86])]" ] }, - "execution_count": 12, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } @@ -267,7 +300,7 @@ }, { "cell_type": "code", - "execution_count": 13, + 
"execution_count": 12, "id": "004ff4db-6f18-43dc-a441-e09fb570f7a1", "metadata": {}, "outputs": [], @@ -277,7 +310,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 13, "id": "c17f0a18-f39d-4e52-8db3-dfe18b8f7423", "metadata": {}, "outputs": [ @@ -287,7 +320,7 @@ "[Atoms(symbols='Si63', pbc=True, cell=[10.86, 10.86, 10.86])]" ] }, - "execution_count": 14, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } @@ -302,36 +335,9 @@ "metadata": {}, "source": [] }, - { - "cell_type": "markdown", - "id": "60991038-6191-4bd5-aa3c-e30fe1a0c976", - "metadata": {}, - "source": [ - "## train initial models\n", - "I just can't stomach starting with comepletely untrained ani models" - ] - }, { "cell_type": "code", - "execution_count": 15, - "id": "3eae7ee0-bd2d-432d-8c78-1eebfb185569", - "metadata": {}, - "outputs": [], - "source": [ - "class CanonicalWriter():\n", - "\n", - " def __init__(self, path):\n", - " self.path = path\n", - "\n", - " def __call__(self): \n", - "\n", - " with Trajectory(self.path, mode='a') as traj: \n", - " traj.write(canonicalize(atoms))" - ] - }, - { - "cell_type": "code", - "execution_count": 16, + "execution_count": 14, "id": "76f8685a-1cfb-4579-8373-c707bb079dc6", "metadata": {}, "outputs": [], @@ -341,175 +347,175 @@ }, { "cell_type": "code", - "execution_count": 17, - "id": "c4b1a101-fd95-4aa8-b94d-f40a5a4274ec", - "metadata": {}, - "outputs": [], - "source": [ - "n_training_frames = 128\n", - "atoms.calc = calc\n", - "dynamics = VelocityVerlet(atoms, timestep=1*units.fs)" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "id": "9f185353-35ad-4a88-9041-2436d4ac7310", - "metadata": {}, - "outputs": [], - "source": [ - "md_logger = MDLogger(np, atoms, 'train.log', stress=True)\n", - "traj_writer = CanonicalWriter('train.traj')\n", - "dynamics.attach(md_logger)\n", - "dynamics.attach(traj_writer)" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "id": "cf5d2836-0cff-4822-bf9f-8e6479e3f225", + "execution_count": 15, + "id": "568bdb51-d2a5-4ff5-9b50-13d4bc111e5f", "metadata": {}, "outputs": [ { - "name": "stdout", + "name": "stderr", "output_type": "stream", "text": [ - "CPU times: user 9.1 s, sys: 5.31 s, total: 14.4 s\n", - "Wall time: 7.42 s\n" + "/home/mike/miniconda3/envs/cascade/lib/python3.11/site-packages/torchani/utils.py:158: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", + " self_energies = torch.tensor(self_energies, dtype=torch.double)\n" ] }, { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 19, - "metadata": {}, - "output_type": "execute_result" + "name": "stdout", + "output_type": "stream", + "text": [ + "On traj 1/2\n", + "Running ML-driven dynamics\n", + "CascadeTrajectory(path=si-diffusion-seed=0.traj, current_timestep=64, last_trusted_timestep=0)\n", + "On traj 2/2\n", + "Running ML-driven dynamics\n", + "CascadeTrajectory(path=si-diffusion-seed=1.traj, current_timestep=64, last_trusted_timestep=0)\n", + "done 0 / 2\n", + "On traj 1/2\n", + "Auditing trajectory\n", + "score < threshold (0.3745401188473625 < 0.5, marking recent segment as trusted\n", + "CascadeTrajectory(path=si-diffusion-seed=0.traj, current_timestep=64, last_trusted_timestep=64)\n", + "On traj 2/2\n", + "Auditing trajectory\n", + "score < threshold (0.034388521115218396 < 0.5, marking recent segment as trusted\n", + 
"CascadeTrajectory(path=si-diffusion-seed=1.traj, current_timestep=64, last_trusted_timestep=64)\n", + "done 0 / 2\n", + "On traj 1/2\n", + "Running ML-driven dynamics\n", + "CascadeTrajectory(path=si-diffusion-seed=0.traj, current_timestep=128, last_trusted_timestep=64)\n", + "On traj 2/2\n", + "Running ML-driven dynamics\n", + "CascadeTrajectory(path=si-diffusion-seed=1.traj, current_timestep=128, last_trusted_timestep=64)\n", + "done 0 / 2\n", + "On traj 1/2\n", + "Auditing trajectory\n", + "score > threshold (0.6688412526636073 > 0.5), running audit calculations and dropping untrusted segment\n", + "CascadeTrajectory(path=si-diffusion-seed=0.traj, current_timestep=64, last_trusted_timestep=64)\n", + "On traj 2/2\n", + "Auditing trajectory\n", + "score < threshold (0.33761517140362796 < 0.5, marking recent segment as trusted\n", + "CascadeTrajectory(path=si-diffusion-seed=1.traj, current_timestep=128, last_trusted_timestep=128)\n", + "done 0 / 2\n", + "On traj 1/2\n", + "Running ML-driven dynamics\n", + "CascadeTrajectory(path=si-diffusion-seed=0.traj, current_timestep=128, last_trusted_timestep=64)\n", + "On traj 2/2\n", + "Traj is completed, continuing\n", + "done 1 / 2\n", + "On traj 1/2\n", + "Auditing trajectory\n", + "score > threshold (0.940523264489604 > 0.5), running audit calculations and dropping untrusted segment\n", + "CascadeTrajectory(path=si-diffusion-seed=0.traj, current_timestep=64, last_trusted_timestep=64)\n", + "On traj 2/2\n", + "Traj is completed, continuing\n", + "done 1 / 2\n", + "On traj 1/2\n", + "Running ML-driven dynamics\n", + "CascadeTrajectory(path=si-diffusion-seed=0.traj, current_timestep=128, last_trusted_timestep=64)\n", + "On traj 2/2\n", + "Traj is completed, continuing\n", + "done 1 / 2\n", + "On traj 1/2\n", + "Auditing trajectory\n", + "score < threshold (0.09367476782809248 < 0.5, marking recent segment as trusted\n", + "CascadeTrajectory(path=si-diffusion-seed=0.traj, current_timestep=128, last_trusted_timestep=128)\n", + "On traj 2/2\n", + "Traj is completed, continuing\n", + "done 1 / 2\n", + "On traj 1/2\n", + "Traj is completed, continuing\n", + "On traj 2/2\n", + "Traj is completed, continuing\n", + "done 2 / 2\n" + ] } ], "source": [ - "%%time\n", - "dynamics.run(n_training_frames)" - ] - }, - { - "cell_type": "markdown", - "id": "a218b17b-b290-42d9-aa67-c545bd6b34c7", - "metadata": {}, - "source": [] - }, - { - "cell_type": "markdown", - "id": "fa54bc5d-929b-403d-a2aa-0309d8c196f0", - "metadata": {}, - "source": [ - "## Set up protype run" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d486f89f-3ee8-42ae-b4f8-15912b16fb1d", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 20, - "id": "a0e17cd2-cf74-446e-8b21-94b16c24afc3", - "metadata": {}, - "outputs": [], - "source": [ - "seeds = [0, 1]" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "id": "568bdb51-d2a5-4ff5-9b50-13d4bc111e5f", - "metadata": {}, - "outputs": [], - "source": [ - "total_steps = 128\n", - "increment_steps = 64\n", + "seeds = [0, 1]\n", "\n", - "# while not done:\n", - "# pass_ix = 1\n", - " \n", - "# # set up the directory to hold the trajectory for this pass\n", - "# run_dir = Path(f'cascade-md') / name\n", - "# pass_dir = run_dir / f'chunk={chunk_ix}-pass={pass_ix}'\n", - "# pass_dir.mkdir(exist_ok=True, parents=True)\n", "\n", - "# # pull in initial conidtions or last frame from the most recent trusted chunk\n", - "# if chunk_ix == 1: \n", - "# atoms = 
initial_conditions[name]\n", - "# else:\n", - "# last_pass = chunk_passes[chunk_ix-1]\n", - "# atoms = read(Path(run_dir)/name/f'chunk={chunk_ix-1}-{last_pass}', \n", - "# index='-1')\n", + "trajectories = [CascadeTrajectory(path=f'si-diffusion-seed={s}.traj', \n", + " starting=atoms.copy()) for s in seeds]\n", + "\n", + "done = False\n", + "total_steps = 128\n", + "increment_steps = 64 # how many steps to run with ML at a time\n", + "done = False\n", + "max_iter = 10\n", + "i = 0\n", + "auditor = RandomAuditor(random_state=42)\n", + "threshold = 0.5\n", "\n", - "# # we save the trajectory in chunks, inluding every pass at simulating that chunk\n", - "# logfile = str(pass_dir / 'md.log')\n", - "# trajfile = str(pass_dir / 'md.traj')\n", + "while not done:\n", + " \n", + " done_ctr = 0 # count how many trajectories are done\n", " \n", - "# # setup the ml-driven dynamics\n", - "# atoms.calc = calc_ml\n", - "# dyn = NPT(atoms,\n", - "# timestep=0.5 * units.fs,\n", - "# temperature_K=298,\n", - "# ttime=100 * units.fs,\n", - "# pfactor=0.01,\n", - "# externalstress=0,\n", - "# logfile=logfile,\n", - "# trajectory=trajfile,\n", - "# append_trajectory=False)\n", - "# # timestep indexing\n", - "# # start = (chunk_ix-1) * chunk_size # the actual starting timestep\n", - "# # stop = min(chunk_size, chunk_size*chunk_ix)\n", - "# # there is probably a nice mathy way to do this\n", - "# resulting_steps = chunk_ix * chunk_size # how many total timesteps will be achieved\n", - "# if resulting_steps < total_steps: \n", - "# chunk_steps = chunk_size\n", - "# else: \n", - "# chunk_steps = total_steps - ((chunk_ix-1)*chunk_size)\n", + " for j, traj in enumerate(trajectories):\n", + " \n", + " ## Check if this trajectory is done\n", + " print(f'On traj {j+1}/{len(trajectories)}')\n", + " if traj.last_trusted_timestep == total_steps: \n", + " done_ctr += 1\n", + " print('Traj is completed, continuing')\n", + " continue\n", "\n", - "# # run the dynamics for this chunk\n", - "# dyn.run(chunk_steps)\n", + " \n", + " ## if we've advanced past a trusted segment, lets audit it\n", + " if traj.current_timestep > traj.last_trusted_timestep: \n", + " print('Auditing trajectory')\n", + " segment = traj.get_untrusted_segment()\n", + " score, audit_frames = auditor.audit(segment, n_audits=32)\n", + " if score > threshold: \n", + " print(f'score > threshold ({score} > {threshold}), running audit calculations and dropping untrusted segment')\n", + " segment = apply_calculator(calc, segment)\n", + " traj.trim_untrusted_segment()\n", + " else:\n", + " print(f'score < threshold ({score} < {threshold}, marking recent segment as trusted')\n", + " traj.last_trusted_timestep = traj.current_timestep\n", "\n", - "# # read in the recent chunk\n", - "# chunk = read(trajfile)\n", - "# break" + " \n", + " # otherwise we can run the ML-driven dynamics \n", + " else:\n", + " # then we run dynamics\n", + " print('Running ML-driven dynamics')\n", + " traj.current.calc = learner.make_calculator(model, device='cpu')\n", + " dyn = VelocityVerlet(atoms=traj.current,\n", + " timestep=1*units.fs,\n", + " trajectory=TrajectoryWriter(traj.path, mode='a')\n", + " )\n", + " dyn.run(increment_steps)\n", + " traj.current_timestep += increment_steps\n", + " print(traj)\n", + " \n", + " i += 1\n", + " print(f'done {done_ctr} / {len(trajectories)}')\n", + " done = done_ctr == len(trajectories) or i == max_iter" ] }, { "cell_type": "code", - "execution_count": 36, - "id": "cf4d7556-2614-4667-8c58-6556d8ec2a78", + "execution_count": 19, + "id": 
"07448032-737c-49ac-9958-2829b95b841e", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "0.5192068632195513" + "" ] }, - "execution_count": 36, + "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "np.random.RandomState(None).uniform(0, 1)" + "[len(t.read()), trajectories)" ] }, { "cell_type": "code", "execution_count": null, - "id": "07448032-737c-49ac-9958-2829b95b841e", + "id": "b19a2812-cc62-4bce-9c63-d8b6ba5b1bd4", "metadata": {}, "outputs": [], "source": []