Skip to content

Commit

Permalink
First notebook for comparing sampling methods
Browse files Browse the repository at this point in the history
  • Loading branch information
WardLT committed Nov 15, 2023
1 parent b65e0d1 commit 62fa41b
Show file tree
Hide file tree
Showing 5 changed files with 692 additions and 24 deletions.
6 changes: 3 additions & 3 deletions notebooks/1_explore-sampling-methods/run-all-methods.sh
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
#! /bin/bash

xyz=../data/exact/caffeine_pm7_None.xyz
-for step_size in 0.02 0.01 0.005; do
+for step_size in 0.04 0.02 0.01 0.005; do
# Do the randomized methods
for method in 0_random-directions-same-distance.ipynb 1_random-directions-variable-distance.ipynb; do
-papermill -p starting_geometry $xyz -p step_size $s $method - > /dev/null
+papermill -p starting_geometry $xyz -p step_size $step_size $method last.ipynb
done

# Test with different reductions for "along axes"
notebook=2_displace-along-axes.ipynb
for n in 1 2 4; do
-papermill -p starting_geometry $xyz -p perturbs_per_evaluation $n $notebook - > /dev/null
+papermill -p starting_geometry $xyz -p perturbs_per_evaluation $n -p step_size $step_size $notebook last.ipynb
done
done
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
},
{
"cell_type": "code",
-"execution_count": 2,
+"execution_count": null,
"id": "ebbbc7f5-3007-420f-861a-9f65f84436be",
"metadata": {
"tags": []
Expand All @@ -29,6 +29,7 @@
"from dscribe.descriptors import MBTR\n",
"from ase.vibrations import VibrationsData\n",
"from ase.db import connect\n",
+"from random import sample\n",
"from pathlib import Path\n",
"from tqdm import tqdm\n",
"import numpy as np\n",
Expand All @@ -47,7 +48,7 @@
},
{
"cell_type": "code",
-"execution_count": 8,
+"execution_count": null,
"id": "99bd4c92-9a7b-4e88-ac45-dbf30fbfc9e0",
"metadata": {
"tags": [
Expand All @@ -56,8 +57,9 @@
},
"outputs": [],
"source": [
-"db_path = '../1_explore-sampling-methods/data/along-axes/caffeine_pm7_None_d=5.00e-03-N=2.db'\n",
-"overwrite = False"
+"db_path: str = '../1_explore-sampling-methods/data/along-axes/caffeine_pm7_None_d=5.00e-03-N=2.db'\n",
+"overwrite: bool = False\n",
+"max_size: int = 10000"
]
},
{
Expand All @@ -70,7 +72,7 @@
},
{
"cell_type": "code",
-"execution_count": 9,
+"execution_count": null,
"id": "a8be3c37-bf1f-4ba4-ba8f-afff6d6bed7d",
"metadata": {
"tags": []
Expand All @@ -94,24 +96,12 @@
},
{
"cell_type": "code",
-"execution_count": 12,
+"execution_count": null,
"id": "75d22086-f020-40d7-8327-1154491b9821",
"metadata": {
"tags": []
},
-"outputs": [
-{
-"ename": "ValueError",
-"evalue": "Already done!",
-"output_type": "error",
-"traceback": [
-"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
-"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
-"Cell \u001b[0;32mIn[12], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (out_dir \u001b[38;5;241m/\u001b[39m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mout_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.json\u001b[39m\u001b[38;5;124m'\u001b[39m)\u001b[38;5;241m.\u001b[39mexists() \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m overwrite:\n\u001b[0;32m----> 2\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mAlready done!\u001b[39m\u001b[38;5;124m'\u001b[39m)\n",
-"\u001b[0;31mValueError\u001b[0m: Already done!"
-]
-}
-],
+"outputs": [],
"source": [
"if (out_dir / f'{out_name}-full.json').exists() and not overwrite:\n",
" raise ValueError('Already done!')"
Expand Down Expand Up @@ -140,6 +130,28 @@
"print(f'Loaded {len(data)} structures')"
]
},
+{
+"cell_type": "markdown",
+"id": "0c8aae57-1863-4bad-a56b-31f7b8a6062b",
+"metadata": {},
+"source": [
+"Downsample if desired"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "e2dfe036-a173-41ff-817b-2e92349b9704",
+"metadata": {
+"tags": []
+},
+"outputs": [],
+"source": [
+"if max_size is not None and len(data) > max_size:\n",
+" data = sample(data, max_size)\n",
+" print(f'Downselected to {len(data)}')"
+]
+},
{
"cell_type": "markdown",
"id": "cb1a8e03-b045-49a4-95fd-61636a48fbad",
Expand Down
4 changes: 2 additions & 2 deletions notebooks/2_testing-fitting-strategies/run-all-dbs.sh
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
#! /bin/bash

notebook=$1
-dbs=$(find ../1_explore-sampling-methods/data/ -name "caffine_pm7_None*.db")
+dbs=$(find ../1_explore-sampling-methods/data/ -name "caffeine_pm7_None*.db")
for db in $dbs; do
echo $db
-papermill -p db_path $db $notebook -
+papermill -p db_path "$db" -p max_size 5000 $notebook last.html
done
Loading

0 comments on commit 62fa41b

Please sign in to comment.