
Commit

Run cibuildwheel on Apple silicon to publish pre-compiled macOS ARM binaries (#78)

* run cibuildwheel on Apple silicon

* fix typo

* add option to release to TestPyPI

* run CI tests on macos-latest-xlarge

hopefully fixes

RuntimeError: MPS backend out of memory (MPS allocated: 0 bytes, other allocations: 0 bytes, max allowed: 1.70 GB). Tried to allocate 0 bytes on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).

* pypa/cibuildwheel set CIBW_ARCHS_MACOS: universal2
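The RuntimeError quoted above suggests `PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0` as a workaround. A minimal sketch of applying it (not part of this commit): the variable must be in the environment before torch initializes its MPS backend, so it belongs at the very top of the entry script or in the CI job's `env:` block.

```python
import os

# Sketch: 0.0 disables the upper limit on MPS memory allocations
# (torch warns this may cause system failure). Must be set before
# torch touches the MPS backend.
os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.0"

# Subsequent torch imports/allocations then run without the 1.70 GB cap:
# import torch
# x = torch.zeros(1, device="mps")
```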
janosh authored Oct 6, 2023
1 parent c8f5985 commit 4b7286e
Showing 6 changed files with 112 additions and 119 deletions.
10 changes: 6 additions & 4 deletions .github/workflows/test.yml
@@ -11,9 +11,9 @@ on:
inputs:
task:
type: choice
- options: [tests, release]
+ options: [tests, release, test-release]
default: tests
- description: Only run tests or release a new version to PyPI after tests pass.
+ description: Run tests, release to PyPI, or release to TestPyPI.

jobs:
tests:
@@ -86,9 +86,10 @@ jobs:
uses: actions/checkout@v4

- name: Build wheels
- uses: pypa/cibuildwheel@v2.15.0
+ uses: pypa/cibuildwheel@v2.16.2
env:
CIBW_BUILD: cp${{ matrix.python-version }}-*
+ CIBW_ARCHS_MACOS: universal2

- name: Save artifact
uses: actions/upload-artifact@v3
@@ -108,8 +109,9 @@ jobs:
name: artifact
path: dist

- - name: Publish to PyPi
+ - name: Publish to PyPi or TestPyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
skip-existing: true
verbose: true
+ repository-url: ${{ github.event_name == 'workflow_dispatch' && github.event.inputs.task == 'test-release' && 'https://test.pypi.org/legacy/' || '' }}
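The `repository-url` line uses the GitHub Actions `cond && a || b` idiom, which behaves like a ternary as long as `a` is truthy (the non-empty URL is). A Python sketch of how the expression resolves (hypothetical helper, not in the workflow):

```python
def resolve_repository_url(event_name: str, task: str) -> str:
    """Mirror the Actions expression: manual test-release runs go to TestPyPI."""
    if event_name == "workflow_dispatch" and task == "test-release":
        return "https://test.pypi.org/legacy/"
    # An empty string leaves the publish action on its default upload
    # endpoint, i.e. the real PyPI.
    return ""

print(resolve_repository_url("workflow_dispatch", "test-release"))
# https://test.pypi.org/legacy/
```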
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
@@ -4,7 +4,7 @@ default_install_hook_types: [pre-commit, commit-msg]

repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
- rev: v0.0.290
+ rev: v0.0.292
hooks:
- id: ruff
args: [--fix]
@@ -27,7 +27,7 @@ repos:
- id: trailing-whitespace

- repo: https://github.com/codespell-project/codespell
- rev: v2.2.5
+ rev: v2.2.6
hooks:
- id: codespell
stages: [commit, commit-msg]
@@ -49,7 +49,7 @@ repos:
- svelte

- repo: https://github.com/pre-commit/mirrors-eslint
- rev: v8.49.0
+ rev: v8.50.0
hooks:
- id: eslint
types: [file]
2 changes: 1 addition & 1 deletion chgnet/data/dataset.py
@@ -520,7 +520,7 @@ def __init__(

self.keys = []
self.keys = [
- (mp_id, graph_id) for mp_id, dct in self.data.items() for graph_id in dic
+ (mp_id, graph_id) for mp_id, dct in self.data.items() for graph_id in dct
]
random.shuffle(self.keys)
print(f"{len(self.data)} mp_ids, {len(self)} structures imported")
3 changes: 0 additions & 3 deletions chgnet/trainer/trainer.py
@@ -177,9 +177,6 @@ def __init__(
self.device = use_device
elif torch.cuda.is_available():
self.device = "cuda"
- # mps is disabled until stable version of torch for mps is released
- # elif torch.backends.mps.is_available():
- # self.device = "mps"
else:
self.device = "cpu"
if self.device == "cuda":
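After the deletion above, the trainer's resolution order is: explicit `use_device`, then CUDA, then CPU, with the MPS branch held back until torch's MPS support stabilizes. The logic can be sketched without torch (a stand-in function, `cuda_available` replaces `torch.cuda.is_available()`):

```python
def resolve_device(use_device=None, cuda_available=False):
    """Sketch of the trainer's device resolution after this commit."""
    if use_device:  # an explicit request always wins
        return use_device
    if cuda_available:
        return "cuda"
    # elif torch.backends.mps.is_available():  # disabled for now
    #     return "mps"
    return "cpu"

print(resolve_device(cuda_available=True))  # cuda
```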
202 changes: 98 additions & 104 deletions examples/crystaltoolkit_relax_viewer.ipynb
@@ -27,22 +27,7 @@
"execution_count": null,
"id": "156e8031",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Cloning into 'chgnet'...\n",
"remote: Enumerating objects: 50, done.\u001b[K\n",
"remote: Counting objects: 100% (50/50), done.\u001b[K\n",
"remote: Compressing objects: 100% (47/47), done.\u001b[K\n",
"remote: Total 50 (delta 1), reused 17 (delta 0), pack-reused 0\u001b[K\n",
"Receiving objects: 100% (50/50), 4.25 MiB | 2.70 MiB/s, done.\n",
"Resolving deltas: 100% (1/1), done.\n",
"zsh:1: no matches found: ./chgnet[crystal-toolkit]\n"
]
}
],
"outputs": [],
"source": [
"try:\n",
" import chgnet # noqa: F401\n",
@@ -123,76 +108,102 @@
"output_type": "stream",
"text": [
"CHGNet initialized with 400,438 parameters\n",
"CHGNet will run on cpu\n",
"CHGNet will run on mps\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/janosh/dev/chgnet/chgnet/model/composition_model.py:177: UserWarning: MPS: no support for int64 min/max ops, casting it to int32 (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/native/mps/operations/ReduceOps.mm:1271.)\n",
" composition_fea = torch.bincount(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" Step Time Energy fmax\n",
"*Force-consistent energies used in optimization.\n",
"FIRE: 0 14:01:10 -51.912251* 27.2278\n",
"FIRE: 1 14:01:10 -54.259518* 12.3964\n",
"FIRE: 2 14:01:10 -54.778671* 8.5672\n",
"FIRE: 3 14:01:11 -55.339821* 5.5388\n",
"FIRE: 4 14:01:11 -55.653206* 7.1592\n",
"FIRE: 5 14:01:11 -56.225849* 6.6752\n",
"FIRE: 6 14:01:11 -56.975388* 4.2375\n",
"FIRE: 7 14:01:11 -57.431259* 4.4837\n",
"FIRE: 8 14:01:11 -57.696171* 5.3055\n",
"FIRE: 9 14:01:11 -57.933193* 3.3038\n",
"FIRE: 10 14:01:11 -57.887894* 6.1535\n",
"FIRE: 11 14:01:11 -57.981998* 4.7339\n",
"FIRE: 12 14:01:12 -58.107471* 3.2390\n",
"FIRE: 13 14:01:12 -58.196518* 2.3609\n",
"FIRE: 14 14:01:12 -58.237015* 2.6211\n",
"FIRE: 15 14:01:12 -58.271477* 3.3198\n",
"FIRE: 16 14:01:12 -58.323418* 3.1086\n",
"FIRE: 17 14:01:12 -58.385509* 2.3118\n",
"FIRE: 18 14:01:12 -58.443253* 1.6311\n",
"FIRE: 19 14:01:12 -58.478363* 2.3351\n",
"FIRE: 20 14:01:12 -58.506485* 2.9915\n",
"FIRE: 21 14:01:13 -58.553890* 2.3584\n",
"FIRE: 22 14:01:13 -58.591610* 1.3229\n",
"FIRE: 23 14:01:13 -58.600597* 2.5641\n",
"FIRE: 24 14:01:13 -58.612297* 2.2842\n",
"FIRE: 25 14:01:13 -58.631424* 1.7666\n",
"FIRE: 26 14:01:13 -58.651585* 1.1141\n",
"FIRE: 27 14:01:14 -58.667530* 1.0229\n",
"FIRE: 28 14:01:14 -58.678112* 1.0305\n",
"FIRE: 29 14:01:14 -58.686382* 1.3254\n",
"FIRE: 30 14:01:14 -58.695862* 1.3551\n",
"FIRE: 31 14:01:14 -58.708542* 1.1167\n",
"FIRE: 32 14:01:14 -58.721897* 0.8152\n",
"FIRE: 33 14:01:14 -58.732655* 0.9976\n",
"FIRE: 34 14:01:14 -58.743347* 1.2707\n",
"FIRE: 35 14:01:14 -58.759544* 1.0066\n",
"FIRE: 36 14:01:15 -58.777481* 0.5213\n",
"FIRE: 37 14:01:15 -58.788338* 1.0805\n",
"FIRE: 38 14:01:15 -58.805302* 1.2441\n",
"FIRE: 39 14:01:15 -58.834023* 0.4629\n",
"FIRE: 40 14:01:15 -58.856319* 1.0140\n",
"FIRE: 41 14:01:15 -58.881325* 0.5240\n",
"FIRE: 42 14:01:15 -58.898716* 1.1566\n",
"FIRE: 43 14:01:15 -58.921391* 0.4211\n",
"FIRE: 44 14:01:15 -58.932678* 0.5672\n",
"FIRE: 45 14:01:15 -58.941673* 0.8468\n",
"FIRE: 46 14:01:16 -58.946671* 0.4729\n",
"FIRE: 47 14:01:16 -58.948338* 0.5703\n",
"FIRE: 48 14:01:16 -58.948872* 0.4885\n",
"FIRE: 49 14:01:16 -58.949738* 0.3384\n",
"FIRE: 50 14:01:16 -58.950603* 0.2155\n",
"FIRE: 51 14:01:16 -58.951290* 0.2596\n",
"FIRE: 52 14:01:16 -58.951885* 0.2818\n",
"FIRE: 53 14:01:16 -58.952572* 0.2948\n",
"FIRE: 54 14:01:17 -58.953487* 0.2832\n",
"FIRE: 55 14:01:17 -58.954651* 0.1820\n",
"FIRE: 56 14:01:17 -58.955776* 0.1377\n",
"FIRE: 57 14:01:17 -58.956646* 0.1858\n",
"FIRE: 58 14:01:17 -58.957542* 0.2483\n",
"FIRE: 59 14:01:17 -58.958771* 0.1507\n",
"FIRE: 60 14:01:17 -58.959930* 0.1098\n",
"FIRE: 61 14:01:17 -58.960972* 0.2491\n",
"FIRE: 62 14:01:17 -58.962578* 0.1265\n",
"FIRE: 63 14:01:17 -58.964127* 0.1622\n",
"FIRE: 64 14:01:18 -58.965885* 0.1447\n",
"FIRE: 65 14:01:18 -58.967422* 0.2064\n",
"FIRE: 66 14:01:18 -58.968880* 0.0730\n"
"FIRE: 0 17:10:46 -39.349243* 93.7997\n",
"FIRE: 1 17:10:46 -53.316616* 10.1811\n",
"FIRE: 2 17:10:46 -53.377773* 15.6013\n",
"FIRE: 3 17:10:47 -54.071163* 11.7506\n",
"FIRE: 4 17:10:47 -54.818359* 5.3271\n",
"FIRE: 5 17:10:48 -55.177044* 7.9573\n",
"FIRE: 6 17:10:48 -55.661133* 9.9137\n",
"FIRE: 7 17:10:49 -56.486736* 6.3212\n",
"FIRE: 8 17:10:49 -57.129395* 4.4612\n",
"FIRE: 9 17:10:50 -57.536762* 6.1146\n",
"FIRE: 10 17:10:50 -57.886269* 2.9151\n",
"FIRE: 11 17:10:51 -57.534672* 9.1977\n",
"FIRE: 12 17:10:52 -57.731918* 6.6787\n",
"FIRE: 13 17:10:52 -57.953892* 3.5996\n",
"FIRE: 14 17:10:53 -58.058907* 2.4661\n",
"FIRE: 15 17:10:53 -58.089428* 4.2591\n",
"FIRE: 16 17:10:53 -58.099186* 4.0957\n",
"FIRE: 17 17:10:54 -58.117531* 3.7736\n",
"FIRE: 18 17:10:54 -58.142269* 3.3059\n",
"FIRE: 19 17:10:54 -58.170528* 2.7055\n",
"FIRE: 20 17:10:55 -58.199223* 1.9944\n",
"FIRE: 21 17:10:55 -58.225574* 1.6872\n",
"FIRE: 22 17:10:56 -58.247837* 1.6420\n",
"FIRE: 23 17:10:56 -58.267609* 1.5754\n",
"FIRE: 24 17:10:57 -58.286213* 1.5527\n",
"FIRE: 25 17:10:58 -58.307659* 1.9544\n",
"FIRE: 26 17:10:58 -58.336651* 2.1167\n",
"FIRE: 27 17:10:58 -58.374733* 1.7955\n",
"FIRE: 28 17:10:59 -58.417065* 1.2746\n",
"FIRE: 29 17:10:59 -58.454224* 1.0842\n",
"FIRE: 30 17:11:00 -58.482494* 1.4909\n",
"FIRE: 31 17:11:00 -58.511620* 1.7936\n",
"FIRE: 32 17:11:00 -58.551266* 1.3864\n",
"FIRE: 33 17:11:00 -58.593399* 0.8655\n",
"FIRE: 34 17:11:01 -58.626717* 1.1029\n",
"FIRE: 35 17:11:01 -58.667667* 1.1856\n",
"FIRE: 36 17:11:02 -58.714378* 0.7611\n",
"FIRE: 37 17:11:02 -58.751740* 1.4115\n",
"FIRE: 38 17:11:02 -58.798595* 0.6277\n",
"FIRE: 39 17:11:03 -58.825634* 1.6683\n",
"FIRE: 40 17:11:04 -58.860550* 0.6463\n",
"FIRE: 41 17:11:04 -58.879448* 1.6940\n",
"FIRE: 42 17:11:05 -58.889172* 1.0395\n",
"FIRE: 43 17:11:05 -58.897785* 0.5511\n",
"FIRE: 44 17:11:06 -58.900936* 0.9203\n",
"FIRE: 45 17:11:06 -58.902317* 0.8079\n",
"FIRE: 46 17:11:06 -58.904602* 0.5999\n",
"FIRE: 47 17:11:06 -58.907150* 0.4713\n",
"FIRE: 48 17:11:07 -58.909378* 0.4035\n",
"FIRE: 49 17:11:07 -58.911129* 0.3782\n",
"FIRE: 50 17:11:07 -58.912685* 0.4900\n",
"FIRE: 51 17:11:08 -58.914524* 0.5925\n",
"FIRE: 52 17:11:08 -58.917168* 0.5787\n",
"FIRE: 53 17:11:09 -58.920650* 0.4262\n",
"FIRE: 54 17:11:09 -58.924351* 0.2559\n",
"FIRE: 55 17:11:10 -58.927425* 0.2542\n",
"FIRE: 56 17:11:10 -58.930111* 0.4618\n",
"FIRE: 57 17:11:10 -58.933582* 0.4244\n",
"FIRE: 58 17:11:11 -58.937565* 0.2129\n",
"FIRE: 59 17:11:11 -58.940548* 0.3162\n",
"FIRE: 60 17:11:11 -58.943970* 0.3788\n",
"FIRE: 61 17:11:11 -58.948582* 0.1709\n",
"FIRE: 62 17:11:12 -58.952435* 0.4052\n",
"FIRE: 63 17:11:12 -58.957687* 0.1760\n",
"FIRE: 64 17:11:13 -58.961376* 0.4583\n",
"FIRE: 65 17:11:13 -58.965767* 0.1113\n",
"FIRE: 66 17:11:13 -58.967800* 0.4162\n",
"FIRE: 67 17:11:13 -58.969425* 0.5236\n",
"FIRE: 68 17:11:14 -58.970615* 0.2232\n",
"FIRE: 69 17:11:14 -58.970974* 0.2027\n",
"FIRE: 70 17:11:14 -58.971069* 0.1704\n",
"FIRE: 71 17:11:14 -58.971218* 0.1290\n",
"FIRE: 72 17:11:14 -58.971375* 0.1228\n",
"FIRE: 73 17:11:15 -58.971523* 0.1153\n",
"FIRE: 74 17:11:15 -58.971649* 0.1073\n",
"FIRE: 75 17:11:15 -58.971767* 0.1409\n",
"FIRE: 76 17:11:15 -58.971931* 0.1412\n",
"FIRE: 77 17:11:15 -58.972164* 0.1046\n",
"FIRE: 78 17:11:15 -58.972416* 0.0777\n"
]
}
],
@@ -248,30 +259,13 @@
"id": "c9f16422",
"metadata": {},
"outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/Users/janosh/dev/crystaltoolkit/crystal_toolkit/components/diffraction_tem.py:18: UserWarning: The TEMDiffractionComponent requires the py4DSTEM package.\n",
- " warn(\"The TEMDiffractionComponent requires the py4DSTEM package.\")\n",
- "/Users/janosh/dev/crystaltoolkit/crystal_toolkit/components/localenv.py:50: UserWarning: Using dscribe SOAP and REMatchKernel requires the dscribe package which was made optional since it in turn requires numba and numba was a common source of installation issues.\n",
- " warn(\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "No module named 'phonopy'\n"
- ]
- },
{
"data": {
"text/html": [
"\n",
" <iframe\n",
" width=\"100%\"\n",
" height=\"800\"\n",
" height=\"650\"\n",
" src=\"http://127.0.0.1:8050/\"\n",
" frameborder=\"0\"\n",
" allowfullscreen\n",
@@ -280,7 +274,7 @@
" "
],
"text/plain": [
"<IPython.lib.display.IFrame at 0x14d031410>"
"<IPython.lib.display.IFrame at 0x3d5938e10>"
]
},
"metadata": {},
@@ -392,7 +386,7 @@
" return structure, fig\n",
"\n",
"\n",
"app.run(mode=\"inline\", height=800, use_reloader=False)"
"app.run(height=800, use_reloader=False)"
]
}
],
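The relaxation trace in the notebook above comes from ASE's FIRE optimizer driving CHGNet: each line logs step, wall time, force-consistent energy (marked `*`), and max force. A small parser for such log lines (a post-processing sketch, not part of the notebook):

```python
def parse_fire_line(line):
    """Parse one 'FIRE: step time energy* fmax' optimizer log line."""
    _, step, _time, energy, fmax = line.split()
    return int(step), float(energy.rstrip("*")), float(fmax)

step, energy, fmax = parse_fire_line("FIRE: 78 17:11:15 -58.972416* 0.0777")
print(step, energy, fmax)  # 78 -58.972416 0.0777
```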
8 changes: 4 additions & 4 deletions examples/fine_tuning.ipynb
@@ -61,7 +61,7 @@
"CHGNet is interfaced to [Pymatgen](https://pymatgen.org/), the training samples (normally coming from different DFTs like VASP),\n",
"need to be converted to [pymatgen.core.structure](https://pymatgen.org/pymatgen.core.html#module-pymatgen.core.structure).\n",
"\n",
"To convert VASP calculation to pymatgen structures and CHGNet labels, you can use the following [code](https://github.com/CederGroupHub/chgnet/blob/main/chgnet/utils/vasp_utils.py):"
"To convert VASP calculation to pymatgen structures and CHGNet labels, you can use the following [code](https://github.com/CederGroupHub/chgnet/blob/main/chgnet/utils/vasp_utils.py):\n"
]
},
{
@@ -79,7 +79,7 @@
"\n",
"# ./my_vasp_calc_dir contains vasprun.xml OSZICAR etc.\n",
"dataset_dict = parse_vasp_dir(file_root=\"./my_vasp_calc_dir\")\n",
"print(dataset_dict.keys())"
"print(list(dataset_dict))"
]
},
{
@@ -93,7 +93,7 @@
"\n",
"For super-large training dataset, like MPtrj dataset, we recommend [converting them to CHGNet graphs](https://github.com/CederGroupHub/chgnet/blob/main/examples/make_graphs.py). This will save significant memory and graph computing time.\n",
"\n",
"Below are the example codes to save the structures."
"Below are the example codes to save the structures.\n"
]
},
{
@@ -155,7 +155,7 @@
"id": "e1611921",
"metadata": {},
"source": [
"## 1. Prepare Training Data"
"## 1. Prepare Training Data\n"
]
},
{
