diff --git a/scripts/evaluate-generated-mofs/0_effect-of-training.ipynb b/scripts/evaluate-generated-mofs/0_effect-of-training.ipynb deleted file mode 100644 index 0ca3b6a4..00000000 --- a/scripts/evaluate-generated-mofs/0_effect-of-training.ipynb +++ /dev/null @@ -1,467 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "1719a71e-77b6-4edf-aea7-06656446ba2f", - "metadata": {}, - "source": [ - "# Are MOFs Generated by Later Models Better?\n", - "We periodically retrain the DiffLinker, and hope that the ones generated by later interations of the model are better." - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "9b732ebf-fe93-4610-bd44-9430b14f79fe", - "metadata": {}, - "outputs": [], - "source": [ - "%matplotlib inline\n", - "from matplotlib import pyplot as plt\n", - "from datetime import datetime\n", - "from pathlib import Path\n", - "from tqdm import tqdm\n", - "import pandas as pd\n", - "import json\n", - "import gzip" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "9efc8697-48e3-4fc5-8181-63ecb29c4fba", - "metadata": {}, - "outputs": [], - "source": [ - "run_dir = Path('../prod-runs/256-nodes/')" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "id": "7d8cb69f-e7ba-4aee-8652-0fed12d6eaf8", - "metadata": {}, - "outputs": [], - "source": [ - "Path('figures').mkdir(exist_ok=True)" - ] - }, - { - "cell_type": "markdown", - "id": "88b14321-1e9c-48c0-872a-7fcbd734ac13", - "metadata": {}, - "source": [ - "## Load the Data from Disk\n", - "And make it compact\n" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "id": "f510924d-3905-41db-aa68-2aecfbe8e75a", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "32341it [01:43, 312.62it/s]\n" - ] - } - ], - "source": [ - "records = []\n", - "with gzip.open(run_dir / 'mofs.json.gz', 'rt') as fp:\n", - " for line in tqdm(fp):\n", - " record = json.loads(line)\n", - "\n", - " # Remove structure 
data, label linkers by anchor\n", - " for k in ['md_trajectory', 'nodes', 'structure', '_id']:\n", - " del record[k]\n", - " for ligand in record.pop('ligands'):\n", - " record[f'ligand.{ligand[\"anchor_type\"]}'] = ligand\n", - " for k in ['xyz', 'dummy_element', 'anchor_type']:\n", - " del ligand[k]\n", - "\n", - " record['time'] = record.pop('times')['created']['$date']\n", - " records.append(pd.json_normalize(record))\n", - "records = pd.concat(records, ignore_index=True)" - ] - }, - { - "cell_type": "markdown", - "id": "b16f764b-dfa9-4928-99e6-898155e24e45", - "metadata": {}, - "source": [ - "Store the model versions" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "id": "6445ea74-8c5b-43bf-97b1-c680a9d3f1a3", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
nametopologycatenationtimegas_storage.CO2structure_stability.uffligand.COO.nameligand.COO.smilesligand.COO.prompt_atomsligand.COO.metadata.model_versionligand.cyano.nameligand.cyano.smilesligand.cyano.prompt_atomsligand.cyano.metadata.model_version
0mof-00a88ea5NoneNone2024-04-13T23:18:21.584Z[10000.0, 0.0862266618]0.209704ligand-a1037294O=C([O-])c1ccc(C=C=[S+2]=C/[C-]=C/c2ccc(C(=O)O...[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13...0ligand-6536db2cN#Cc1ccc(C#CC#CC#Cc2ccc(C#N)cc2)cc1[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], [12, ...0
1mof-89fd0977NoneNone2024-04-13T23:18:00.040Z[10000.0, 0.0756631463]0.228959ligand-0bb2fcf6[O-][C+](O)[C-]1[CH+][CH+][C+]([C][C][CH+][C][...[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13...0ligand-b53051b9[N-2][C][c+]1[cH-][cH+][c-]([C][C][CH+][C][C][...[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], [12, ...0
2mof-1ef2070fNoneNone2024-04-13T23:17:59.318Z[10000.0, 0.2165945735]0.201553ligand-0fa742f0O=C([O-])c1ccc(C#C/[S+]=C/C#Cc2ccc(C(=O)O)cc2)cc1[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13...0ligand-145bcf48[N-2][C][c+]1[cH-][cH+][c-]([C][CH+][CH+][C][C...[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], [12, ...0
3mof-b3a06d94NoneNone2024-04-13T23:18:01.291Z[10000.0, 0.0723458327]0.073941ligand-0fa742f0O=C([O-])c1ccc(C#C/[S+]=C/C#Cc2ccc(C(=O)O)cc2)cc1[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13...0ligand-818dbabc[N-2][C][c+]1[cH-][cH+][c-]([C][CH+][CH+][N-][...[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], [12, ...0
4mof-9d97b1ecNoneNone2024-04-13T23:18:05.109Z[10000.0, 0.0714736096]0.260291ligand-0bb2fcf6[O-][C+](O)[C-]1[CH+][CH+][C+]([C][C][CH+][C][...[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13...0ligand-b9f216ab[N-2][C][c+]1[cH-][cH+][c-]([C][C][CH+][C][C][...[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], [12, ...0
\n", - "
" - ], - "text/plain": [ - " name topology catenation time \\\n", - "0 mof-00a88ea5 None None 2024-04-13T23:18:21.584Z \n", - "1 mof-89fd0977 None None 2024-04-13T23:18:00.040Z \n", - "2 mof-1ef2070f None None 2024-04-13T23:17:59.318Z \n", - "3 mof-b3a06d94 None None 2024-04-13T23:18:01.291Z \n", - "4 mof-9d97b1ec None None 2024-04-13T23:18:05.109Z \n", - "\n", - " gas_storage.CO2 structure_stability.uff ligand.COO.name \\\n", - "0 [10000.0, 0.0862266618] 0.209704 ligand-a1037294 \n", - "1 [10000.0, 0.0756631463] 0.228959 ligand-0bb2fcf6 \n", - "2 [10000.0, 0.2165945735] 0.201553 ligand-0fa742f0 \n", - "3 [10000.0, 0.0723458327] 0.073941 ligand-0fa742f0 \n", - "4 [10000.0, 0.0714736096] 0.260291 ligand-0bb2fcf6 \n", - "\n", - " ligand.COO.smiles \\\n", - "0 O=C([O-])c1ccc(C=C=[S+2]=C/[C-]=C/c2ccc(C(=O)O... \n", - "1 [O-][C+](O)[C-]1[CH+][CH+][C+]([C][C][CH+][C][... \n", - "2 O=C([O-])c1ccc(C#C/[S+]=C/C#Cc2ccc(C(=O)O)cc2)cc1 \n", - "3 O=C([O-])c1ccc(C#C/[S+]=C/C#Cc2ccc(C(=O)O)cc2)cc1 \n", - "4 [O-][C+](O)[C-]1[CH+][CH+][C+]([C][C][CH+][C][... \n", - "\n", - " ligand.COO.prompt_atoms \\\n", - "0 [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13... \n", - "1 [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13... \n", - "2 [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13... \n", - "3 [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13... \n", - "4 [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13... \n", - "\n", - " ligand.COO.metadata.model_version ligand.cyano.name \\\n", - "0 0 ligand-6536db2c \n", - "1 0 ligand-b53051b9 \n", - "2 0 ligand-145bcf48 \n", - "3 0 ligand-818dbabc \n", - "4 0 ligand-b9f216ab \n", - "\n", - " ligand.cyano.smiles \\\n", - "0 N#Cc1ccc(C#CC#CC#Cc2ccc(C#N)cc2)cc1 \n", - "1 [N-2][C][c+]1[cH-][cH+][c-]([C][C][CH+][C][C][... \n", - "2 [N-2][C][c+]1[cH-][cH+][c-]([C][CH+][CH+][C][C... \n", - "3 [N-2][C][c+]1[cH-][cH+][c-]([C][CH+][CH+][N-][... \n", - "4 [N-2][C][c+]1[cH-][cH+][c-]([C][C][CH+][C][C][... 
\n", - "\n", - " ligand.cyano.prompt_atoms \\\n", - "0 [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], [12, ... \n", - "1 [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], [12, ... \n", - "2 [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], [12, ... \n", - "3 [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], [12, ... \n", - "4 [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], [12, ... \n", - "\n", - " ligand.cyano.metadata.model_version \n", - "0 0 \n", - "1 0 \n", - "2 0 \n", - "3 0 \n", - "4 0 " - ] - }, - "execution_count": 17, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "records.head()" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "id": "70f48759-1d7e-4ec3-bbce-1112ebc144ca", - "metadata": {}, - "outputs": [], - "source": [ - "records['model_version'] = records[['ligand.cyano.metadata.model_version', 'ligand.COO.metadata.model_version']].max(axis=1)" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "id": "e9be34a8-308d-49de-a7aa-8e6ffa7bbb15", - "metadata": {}, - "outputs": [], - "source": [ - "records['time'] = records['time'].apply(lambda x: datetime.strptime(x, '%Y-%m-%dT%H:%M:%S.%fZ'))\n", - "records['walltime'] = (records['time'] - records['time'].min()).apply(lambda x: x.total_seconds())" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "id": "67963d35-eccc-422f-a2d2-060b5796431b", - "metadata": {}, - "outputs": [], - "source": [ - "records.sort_values('walltime', inplace=True)" - ] - }, - { - "cell_type": "markdown", - "id": "54c1dcfe-036e-4393-964d-16cd8a746608", - "metadata": {}, - "source": [ - "## Plot Stability over Time\n", - "Do they get better or worse over time?" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "id": "88d8f1b6-6343-4207-a635-047b4bf3615c", - "metadata": {}, - "outputs": [ - { - "data": { - "image/png": "", - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "fig, ax = plt.subplots(figsize=(3.5, 2.))\n", - "\n", - "sc = ax.scatter(records['walltime'] / 60, records['structure_stability.uff'] * 100, s=10,\n", - " c=records['model_version'])\n", - "\n", - "fig.colorbar(sc, label='Model Version')\n", - "\n", - "ax.set_xlabel('Time (m)')\n", - "ax.set_ylabel('Strain (%)')\n", - "\n", - "fig.tight_layout()\n", - "fig.savefig('figures/stability-over-time.png', dpi=320)" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "id": "20859044-0e0e-44b9-b32d-c155782403bd", - "metadata": {}, - "outputs": [ - { - "data": { - "image/png": "", - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "fig, ax = plt.subplots(figsize=(3.5, 2.))\n", - "\n", - "records['cumulative_found'] = (records['structure_stability.uff'] < 0.1).cumsum()\n", - "sc = ax.scatter(\n", - " records['walltime'] / 60,\n", - " records['cumulative_found'],\n", - " s=10, c=records['model_version']\n", - ")\n", - "#ax.step(records['walltime'] / 60, count, zorder=-1, c='k', lw=1)\n", - "\n", - "fig.colorbar(sc, label='Model Version')\n", - "\n", - "ax.set_xlabel('Time (min)')\n", - "ax.set_ylabel('Stable Found')\n", - "\n", - "fig.tight_layout()\n", - "fig.savefig('figures/stability-over-time-step.png', dpi=320)" - ] - }, - { - "cell_type": "markdown", - "id": "b525d928-fff2-410c-87d7-3a5339e6dc65", - "metadata": {}, - "source": [ - "Save results" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "id": "858ecab3-c4cc-458b-9472-9988d31b5a1d", - "metadata": {}, - "outputs": [], - "source": [ - "summary_dir = Path('summaries')\n", - "summary_dir.mkdir(exist_ok=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "id": "d0e8f32e-2f60-48df-9c6f-9758a6a2e4f8", - "metadata": {}, - "outputs": [], - "source": [ - "records.to_csv(summary_dir / f'{run_dir.name}.csv.gz', index=False)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f7b64b08-6697-4069-95d3-522e0fa6b0d1", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.8" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/scripts/evaluate-generated-mofs/0_summarize-model-outcomes.ipynb 
b/scripts/evaluate-generated-mofs/0_summarize-model-outcomes.ipynb new file mode 100644 index 00000000..5c2caa26 --- /dev/null +++ b/scripts/evaluate-generated-mofs/0_summarize-model-outcomes.ipynb @@ -0,0 +1,467 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "1719a71e-77b6-4edf-aea7-06656446ba2f", + "metadata": {}, + "source": [ + "# Are MOFs Generated by Later Models Better?\n", + "We periodically retrain the DiffLinker, and hope that the MOFs generated by later iterations of the model are better." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "9b732ebf-fe93-4610-bd44-9430b14f79fe", + "metadata": {}, + "outputs": [], + "source": [ + "from matplotlib import pyplot as plt\n", + "from datetime import datetime\n", + "from pathlib import Path\n", + "from tqdm import tqdm\n", + "import pandas as pd\n", + "import json\n", + "import gzip" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "9efc8697-48e3-4fc5-8181-63ecb29c4fba", + "metadata": {}, + "outputs": [], + "source": [ + "run_dir = Path('../prod-runs/64-nodes_no-retrain_repeat-2/')" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "7d8cb69f-e7ba-4aee-8652-0fed12d6eaf8", + "metadata": {}, + "outputs": [], + "source": [ + "Path('figures').mkdir(exist_ok=True)" + ] + }, + { + "cell_type": "markdown", + "id": "88b14321-1e9c-48c0-872a-7fcbd734ac13", + "metadata": {}, + "source": [ + "## Load the Data from Disk\n", + "And make it compact\n" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "f510924d-3905-41db-aa68-2aecfbe8e75a", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "9770it [00:46, 207.89it/s]\n" + ] + } + ], + "source": [ + "records = []\n", + "with gzip.open(run_dir / 'mofs.json.gz', 'rt') as fp:\n", + " for line in tqdm(fp):\n", + " record = json.loads(line)\n", + "\n", + " # Remove structure data, label linkers by anchor\n", + " for k in ['md_trajectory', 'nodes', 
'structure', '_id']:\n", + " del record[k]\n", + " for ligand in record.pop('ligands'):\n", + " record[f'ligand.{ligand[\"anchor_type\"]}'] = ligand\n", + " for k in ['xyz', 'dummy_element', 'anchor_type']:\n", + " del ligand[k]\n", + "\n", + " record['time'] = record.pop('times')['created']['$date']\n", + " records.append(pd.json_normalize(record))\n", + "records = pd.concat(records, ignore_index=True)" + ] + }, + { + "cell_type": "markdown", + "id": "b16f764b-dfa9-4928-99e6-898155e24e45", + "metadata": {}, + "source": [ + "Store the model versions" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "6445ea74-8c5b-43bf-97b1-c680a9d3f1a3", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
nametopologycatenationtimegas_storage.CO2structure_stability.uffligand.COO.nameligand.COO.smilesligand.COO.prompt_atomsligand.COO.metadata.model_versionligand.cyano.nameligand.cyano.smilesligand.cyano.prompt_atomsligand.cyano.metadata.model_version
0mof-0c85fcdcNoneNone2024-10-05T18:34:31.991Z0.0707120.198765ligand-f36f085aO=C([O-])c1ccc(C=[N+]=C=CC#Cc2ccc(C(=O)O)cc2)cc1[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13...0ligand-b1ae0877[N-2][C][c+]1[cH-][cH+][c-]([C][CH+][CH+][C][C...[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], [12, ...0
1mof-71f74056NoneNone2024-10-05T18:34:31.269Z0.0877370.201572ligand-f36f085aO=C([O-])c1ccc(C=[N+]=C=CC#Cc2ccc(C(=O)O)cc2)cc1[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13...0ligand-fd4b41eb[N-2][C][c+]1[cH-][cH+][c-]([C][CH+][CH+][C][C...[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], [12, ...0
2mof-12eebf0fNoneNone2024-10-05T18:34:30.214Z0.0747910.288125ligand-f36f085aO=C([O-])c1ccc(C=[N+]=C=CC#Cc2ccc(C(=O)O)cc2)cc1[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13...0ligand-a3cb2664[N-2][C][c+]1[cH-][cH+][c-]([C][C][C][C][C][CH...[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], [12, ...0
3mof-28b7f140NoneNone2024-10-05T18:34:30.073Z0.0945160.157505ligand-f36f085aO=C([O-])c1ccc(C=[N+]=C=CC#Cc2ccc(C(=O)O)cc2)cc1[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13...0ligand-7cc6d992[N-2][C][c+]1[cH-][cH+][c-]([C][C][CH+][C][C][...[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], [12, ...0
4mof-93190b3cNoneNone2024-10-05T18:34:30.003Z0.0791150.230479ligand-f36f085aO=C([O-])c1ccc(C=[N+]=C=CC#Cc2ccc(C(=O)O)cc2)cc1[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13...0ligand-d6b66392[N-2][C][c+]1[cH-][cH+][c-]([C][C][CH+][CH+][C...[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], [12, ...0
\n", + "
" + ], + "text/plain": [ + " name topology catenation time \\\n", + "0 mof-0c85fcdc None None 2024-10-05T18:34:31.991Z \n", + "1 mof-71f74056 None None 2024-10-05T18:34:31.269Z \n", + "2 mof-12eebf0f None None 2024-10-05T18:34:30.214Z \n", + "3 mof-28b7f140 None None 2024-10-05T18:34:30.073Z \n", + "4 mof-93190b3c None None 2024-10-05T18:34:30.003Z \n", + "\n", + " gas_storage.CO2 structure_stability.uff ligand.COO.name \\\n", + "0 0.070712 0.198765 ligand-f36f085a \n", + "1 0.087737 0.201572 ligand-f36f085a \n", + "2 0.074791 0.288125 ligand-f36f085a \n", + "3 0.094516 0.157505 ligand-f36f085a \n", + "4 0.079115 0.230479 ligand-f36f085a \n", + "\n", + " ligand.COO.smiles \\\n", + "0 O=C([O-])c1ccc(C=[N+]=C=CC#Cc2ccc(C(=O)O)cc2)cc1 \n", + "1 O=C([O-])c1ccc(C=[N+]=C=CC#Cc2ccc(C(=O)O)cc2)cc1 \n", + "2 O=C([O-])c1ccc(C=[N+]=C=CC#Cc2ccc(C(=O)O)cc2)cc1 \n", + "3 O=C([O-])c1ccc(C=[N+]=C=CC#Cc2ccc(C(=O)O)cc2)cc1 \n", + "4 O=C([O-])c1ccc(C=[N+]=C=CC#Cc2ccc(C(=O)O)cc2)cc1 \n", + "\n", + " ligand.COO.prompt_atoms \\\n", + "0 [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13... \n", + "1 [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13... \n", + "2 [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13... \n", + "3 [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13... \n", + "4 [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13... \n", + "\n", + " ligand.COO.metadata.model_version ligand.cyano.name \\\n", + "0 0 ligand-b1ae0877 \n", + "1 0 ligand-fd4b41eb \n", + "2 0 ligand-a3cb2664 \n", + "3 0 ligand-7cc6d992 \n", + "4 0 ligand-d6b66392 \n", + "\n", + " ligand.cyano.smiles \\\n", + "0 [N-2][C][c+]1[cH-][cH+][c-]([C][CH+][CH+][C][C... \n", + "1 [N-2][C][c+]1[cH-][cH+][c-]([C][CH+][CH+][C][C... \n", + "2 [N-2][C][c+]1[cH-][cH+][c-]([C][C][C][C][C][CH... \n", + "3 [N-2][C][c+]1[cH-][cH+][c-]([C][C][CH+][C][C][... \n", + "4 [N-2][C][c+]1[cH-][cH+][c-]([C][C][CH+][CH+][C... \n", + "\n", + " ligand.cyano.prompt_atoms \\\n", + "0 [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], [12, ... 
\n", + "1 [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], [12, ... \n", + "2 [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], [12, ... \n", + "3 [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], [12, ... \n", + "4 [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], [12, ... \n", + "\n", + " ligand.cyano.metadata.model_version \n", + "0 0 \n", + "1 0 \n", + "2 0 \n", + "3 0 \n", + "4 0 " + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "records.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "70f48759-1d7e-4ec3-bbce-1112ebc144ca", + "metadata": {}, + "outputs": [], + "source": [ + "records['model_version'] = records[['ligand.cyano.metadata.model_version', 'ligand.COO.metadata.model_version']].max(axis=1)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "e9be34a8-308d-49de-a7aa-8e6ffa7bbb15", + "metadata": {}, + "outputs": [], + "source": [ + "records['time'] = records['time'].apply(lambda x: datetime.strptime(x[:x.index(\".\")] + \"Z\" if \".\" in x else x, '%Y-%m-%dT%H:%M:%SZ'))\n", + "records['walltime'] = (records['time'] - records['time'].min()).apply(lambda x: x.total_seconds())" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "67963d35-eccc-422f-a2d2-060b5796431b", + "metadata": {}, + "outputs": [], + "source": [ + "records.sort_values('walltime', inplace=True)" + ] + }, + { + "cell_type": "markdown", + "id": "54c1dcfe-036e-4393-964d-16cd8a746608", + "metadata": {}, + "source": [ + "## Plot Stability over Time\n", + "Do they get better or worse over time?" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "88d8f1b6-6343-4207-a635-047b4bf3615c", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots(figsize=(3.5, 2.))\n", + "\n", + "sc = ax.scatter(records['walltime'] / 60, records['structure_stability.uff'] * 100, s=10,\n", + " c=records['model_version'])\n", + "\n", + "fig.colorbar(sc, label='Model Version')\n", + "\n", + "ax.set_xlabel('Time (m)')\n", + "ax.set_ylabel('Strain (%)')\n", + "\n", + "fig.tight_layout()\n", + "fig.savefig('figures/stability-over-time.png', dpi=320)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "20859044-0e0e-44b9-b32d-c155782403bd", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots(figsize=(3.5, 2.))\n", + "\n", + "records['cumulative_found'] = (records['structure_stability.uff'] < 0.1).cumsum()\n", + "sc = ax.scatter(\n", + " records['walltime'] / 60,\n", + " records['cumulative_found'],\n", + " s=10, c=records['model_version']\n", + ")\n", + "#ax.step(records['walltime'] / 60, count, zorder=-1, c='k', lw=1)\n", + "\n", + "fig.colorbar(sc, label='Model Version')\n", + "\n", + "ax.set_xlabel('Time (min)')\n", + "ax.set_ylabel('Stable Found')\n", + "\n", + "fig.tight_layout()\n", + "fig.savefig('figures/stability-over-time-step.png', dpi=320)\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "b525d928-fff2-410c-87d7-3a5339e6dc65", + "metadata": {}, + "source": [ + "Save results" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "858ecab3-c4cc-458b-9472-9988d31b5a1d", + "metadata": {}, + "outputs": [], + "source": [ + "summary_dir = Path('summaries')\n", + "summary_dir.mkdir(exist_ok=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "d0e8f32e-2f60-48df-9c6f-9758a6a2e4f8", + "metadata": {}, + "outputs": [], + "source": [ + "records.to_csv(summary_dir / f'{run_dir.name}.csv.gz', index=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f7b64b08-6697-4069-95d3-522e0fa6b0d1", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.8" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/scripts/evaluate-generated-mofs/1_effect-of-scale.ipynb 
b/scripts/evaluate-generated-mofs/1_effect-of-scale.ipynb index 5aa1a03b..b8772979 100644 --- a/scripts/evaluate-generated-mofs/1_effect-of-scale.ipynb +++ b/scripts/evaluate-generated-mofs/1_effect-of-scale.ipynb @@ -11,7 +11,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 21, "id": "a838a8d1-1402-4371-a3d6-54d63e91f034", "metadata": {}, "outputs": [], @@ -19,6 +19,7 @@ "%matplotlib inline\n", "from matplotlib import pyplot as plt\n", "from pathlib import Path\n", + "from itertools import chain\n", "import pandas as pd\n", "import numpy as np" ] @@ -34,16 +35,25 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 22, "id": "27ca604d-d7d1-421d-b636-29145e3a0807", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Found 6 runs: 32 32 64 128 256 450\n" + ] + } + ], "source": [ "summaries = []\n", - "for path in Path('summaries').glob('*csv.gz'):\n", + "for path in chain(Path('summaries').glob('*-nodes.csv.gz'), Path('summaries').glob('*-nodes_repeat-*.csv.gz')):\n", " count = int(path.name.split(\"-\")[0])\n", " summaries.append([count, pd.read_csv(path)])\n", - "summaries.sort(key=lambda x: x[0])" + "summaries.sort(key=lambda x: x[0])\n", + "print(f'Found {len(summaries)} runs:', \" \".join(str(x[0]) for x in summaries))" ] }, { @@ -57,42 +67,58 @@ }, { "cell_type": "code", - "execution_count": 3, - "id": "cd067d19-e292-4cdd-9bc4-cd754b49532a", + "execution_count": 25, + "id": "e6e85eea-4df6-412f-acf0-4d2b86fa88e2", "metadata": {}, "outputs": [], "source": [ + "sizes = sorted(set(x[0] for x in summaries))\n", "cmap = plt.get_cmap('copper_r')\n", - "steps = np.linspace(0.1, 1., len(summaries))" + "steps = np.linspace(0.2, 1., len(sizes))\n", + "colors = dict((size, cmap(step)) for size, step in zip(sizes, steps))" ] }, { "cell_type": "code", - "execution_count": 4, - "id": "a0c741fc-b772-46fc-9fec-dcbe9971008f", + "execution_count": 26, + "id": 
"f9b000b1-bf87-4812-a001-50189252f639", "metadata": {}, "outputs": [ { "data": { - "image/png": "", "text/plain": [ - "
" + "{32: (0.9882350615917502, 0.62496, 0.398, 1.0),\n", + " 64: (0.7411762961938126, 0.46871999999999997, 0.2985, 1.0),\n", + " 128: (0.49411753079587517, 0.31248000000000004, 0.199, 1.0),\n", + " 256: (0.24705876539793759, 0.15623999999999993, 0.09949999999999998, 1.0),\n", + " 450: (0.0, 0.0, 0.0, 1.0)}" ] }, + "execution_count": 26, "metadata": {}, - "output_type": "display_data" + "output_type": "execute_result" } ], + "source": [ + "colors" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "a0c741fc-b772-46fc-9fec-dcbe9971008f", + "metadata": {}, + "outputs": [], "source": [ "fig, ax = plt.subplots(figsize=(3.5, 2.))\n", "\n", - "for (count, summary), step in zip(summaries, steps):\n", + "for count, summary in summaries:\n", " summary = summary.query(f'walltime > {15 * 60}')\n", " ax.plot(\n", " summary['walltime'] / 60,\n", " summary['cumulative_found'] / count / summary['walltime'] * 3600,\n", " '-',\n", - " color=cmap(step),\n", + " color=colors[count],\n", " label=f'N={count}'\n", " )\n", "ax.legend(fontsize=6)\n", @@ -105,10 +131,57 @@ "fig.savefig('figures/stable-found-per-node-hour.pdf')" ] }, + { + "cell_type": "code", + "execution_count": 30, + "id": "085439d4-1eaa-4102-bd74-23661a27ec69", + "metadata": {}, + "outputs": [], + "source": [ + "fig, ax = plt.subplots(figsize=(3.5, 2.))\n", + "\n", + "lowest_slope = np.mean(\n", + " [summary['cumulative_found'].iloc[-1] / summary['walltime'].iloc[-1] * 3600\n", + " for count, summary in summaries if count == min(colors.keys())]\n", + ")\n", + "done = set()\n", + "for count, summary in summaries:\n", + " ax.plot(\n", + " summary['walltime'] / 3600,\n", + " summary['cumulative_found'],\n", + " '-',\n", + " color=colors[count],\n", + " alpha=0.8,\n", + " label=f'N={count}' if count not in done else None\n", + " )\n", + " done.add(count)\n", + "\n", + "ax.set_xlim(0, 2.8)\n", + "for count, summary in summaries:\n", + " ax.plot(ax.get_xlim(), np.multiply(ax.get_xlim(), lowest_slope * (count 
- 1) / 31), '--', lw=1, color=colors[count])\n", + " \n", + "ax.legend(fontsize=6)\n", + "\n", + "ax.set_xlabel('Walltime (hr)')\n", + "ax.set_ylabel('Stable MOFs Found')\n", + "\n", + "fig.tight_layout()\n", + "fig.savefig('figures/stable-found-per-hour.png', dpi=320)\n", + "fig.savefig('figures/stable-found-per-hour.pdf')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a98ccfb2-4c17-4e91-8a77-dbea2bde6056", + "metadata": {}, + "outputs": [], + "source": [] + }, { "cell_type": "code", "execution_count": null, - "id": "03e91217-cb40-46a6-b143-8e2ff6d44694", + "id": "25317804-b828-4f0a-ab5e-2ab66f61ea86", "metadata": {}, "outputs": [], "source": [] diff --git a/scripts/evaluate-generated-mofs/2_effect-of-training.ipynb b/scripts/evaluate-generated-mofs/2_effect-of-training.ipynb new file mode 100644 index 00000000..9bbe8356 --- /dev/null +++ b/scripts/evaluate-generated-mofs/2_effect-of-training.ipynb @@ -0,0 +1,232 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "5211ca75-c0ed-4e81-ae37-2a21533656f1", + "metadata": {}, + "source": [ + "# Evaluate the Effect of Training \n", + "We can assess whether retraining Difflinker leads to improved performance in two ways:\n", + "1. Evaluate how much the success rate improves with re-training\n", + "2. 
The difference between the total number of stable MOFs found w/ and w/o a closed loop" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "e0c564c2-ec6d-4d55-bc8e-944a80d35598", + "metadata": {}, + "outputs": [], + "source": [ + "from itertools import chain\n", + "from scipy.interpolate import interp1d\n", + "from pathlib import Path\n", + "import pandas as pd" + ] + }, + { + "cell_type": "markdown", + "id": "c56bd56d-a264-4e11-a7ec-e3bb62268e42", + "metadata": {}, + "source": [ + "## Route 1: Measure Success Rate by Model Generation" + ] + }, + { + "cell_type": "markdown", + "id": "c8bb1426-af24-4159-b1c4-55e9c203bfde", + "metadata": {}, + "source": [ + "## Route 2: Assess workflow outcomes w/ and w/o retraining\n", + "Show that it gets better" + ] + }, + { + "cell_type": "markdown", + "id": "10478466-dfd8-41c0-bf53-20dcb394fb91", + "metadata": {}, + "source": [ + "### Get the \"Stable Found\" at 90 minutes\n", + "Loop over all runs and store: scale, if retrained or not, and the number of stable MOFs found after 90 minutes. \n", + "The 450-node run switches how it trained DiffLinker at around 90 minutes, and we don't want to study that effect yet." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 33, + "id": "aa46b598-7a7f-4f13-8a7c-c991ef4e2013", + "metadata": {}, + "outputs": [], + "source": [ + "hours = 1.5" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "id": "98400659-f017-4b68-8ac7-1e9643c10c65", + "metadata": {}, + "outputs": [], + "source": [ + "success_data = []\n", + "for path in chain(Path('summaries').glob('*-nodes.csv.gz'), Path('summaries').glob('*-nodes_repeat-*.csv.gz'), Path('summaries').glob('*no-retrain*.csv.gz')):\n", + " # Get metadata\n", + " count = int(path.name.split(\"-\")[0])\n", + " retrain = 'no-retrain' not in path.name\n", + "\n", + " # Pull the success rate\n", + " mofs = pd.read_csv(path)\n", + " num_found = interp1d(mofs['walltime'], mofs['cumulative_found'], kind='previous')(hours * 3600).item()\n", + "\n", + " success_data.append({\n", + " 'nodes': count,\n", + " 'retrain': retrain,\n", + " 'found': num_found,\n", + " 'found_node-hr': num_found / (count * hours)\n", + " })\n", + "success_data = pd.DataFrame(success_data)" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "id": "ac0c533f-f071-4e74-95d1-77fb0a9e17d9", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
foundfound_node-hr
nodesretrain
32False133.02.770833
True313.06.520833
64False426.54.442708
True641.06.677083
128True1622.08.447917
256True3633.09.460938
450True6554.09.709630
\n", + "
" + ], + "text/plain": [ + " found found_node-hr\n", + "nodes retrain \n", + "32 False 133.0 2.770833\n", + " True 313.0 6.520833\n", + "64 False 426.5 4.442708\n", + " True 641.0 6.677083\n", + "128 True 1622.0 8.447917\n", + "256 True 3633.0 9.460938\n", + "450 True 6554.0 9.709630" + ] + }, + "execution_count": 35, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "success_data.groupby(['nodes', 'retrain']).mean()" + ] + }, + { + "cell_type": "markdown", + "id": "a973a7cb-e161-429e-9707-4f6b6a608bd2", + "metadata": {}, + "source": [ + "TBD: Make a plot" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "269787d5-5319-47f7-83ce-1a807bf14583", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.8" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/scripts/evaluate-generated-mofs/figures/stability-over-time-step.png b/scripts/evaluate-generated-mofs/figures/stability-over-time-step.png index c4afb8e1..c3ab0d46 100644 Binary files a/scripts/evaluate-generated-mofs/figures/stability-over-time-step.png and b/scripts/evaluate-generated-mofs/figures/stability-over-time-step.png differ diff --git a/scripts/evaluate-generated-mofs/figures/stability-over-time.png b/scripts/evaluate-generated-mofs/figures/stability-over-time.png index 9339969e..81c1ef5a 100644 Binary files a/scripts/evaluate-generated-mofs/figures/stability-over-time.png and b/scripts/evaluate-generated-mofs/figures/stability-over-time.png differ diff --git a/scripts/evaluate-generated-mofs/figures/stable-found-per-hour.pdf 
b/scripts/evaluate-generated-mofs/figures/stable-found-per-hour.pdf new file mode 100644 index 00000000..8521d73c Binary files /dev/null and b/scripts/evaluate-generated-mofs/figures/stable-found-per-hour.pdf differ diff --git a/scripts/evaluate-generated-mofs/figures/stable-found-per-hour.png b/scripts/evaluate-generated-mofs/figures/stable-found-per-hour.png new file mode 100644 index 00000000..6d2e0c6b Binary files /dev/null and b/scripts/evaluate-generated-mofs/figures/stable-found-per-hour.png differ diff --git a/scripts/evaluate-generated-mofs/figures/stable-found-per-node-hour.pdf b/scripts/evaluate-generated-mofs/figures/stable-found-per-node-hour.pdf index a56d34cc..b684127d 100644 Binary files a/scripts/evaluate-generated-mofs/figures/stable-found-per-node-hour.pdf and b/scripts/evaluate-generated-mofs/figures/stable-found-per-node-hour.pdf differ diff --git a/scripts/evaluate-generated-mofs/figures/stable-found-per-node-hour.png b/scripts/evaluate-generated-mofs/figures/stable-found-per-node-hour.png index 5c0ee7d8..890b9151 100644 Binary files a/scripts/evaluate-generated-mofs/figures/stable-found-per-node-hour.png and b/scripts/evaluate-generated-mofs/figures/stable-found-per-node-hour.png differ