Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Run cibuildwheel on Apple silicon to publish pre-compiled macOS ARM binaries #78

Merged
merged 5 commits into from
Oct 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ on:
inputs:
task:
type: choice
options: [tests, release]
options: [tests, release, test-release]
default: tests
description: Only run tests or release a new version to PyPI after tests pass.
description: Run tests, release to PyPI, or release to TestPyPI.

jobs:
tests:
Expand Down Expand Up @@ -86,9 +86,10 @@ jobs:
uses: actions/checkout@v4

- name: Build wheels
uses: pypa/cibuildwheel@v2.15.0
uses: pypa/cibuildwheel@v2.16.2
env:
CIBW_BUILD: cp${{ matrix.python-version }}-*
CIBW_ARCHS_MACOS: universal2

- name: Save artifact
uses: actions/upload-artifact@v3
Expand All @@ -108,8 +109,9 @@ jobs:
name: artifact
path: dist

- name: Publish to PyPi
- name: Publish to PyPi or TestPyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
skip-existing: true
verbose: true
repository-url: ${{ github.event_name == 'workflow_dispatch' && github.event.inputs.task == 'test-release' && 'https://test.pypi.org/legacy/' || '' }}
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ default_install_hook_types: [pre-commit, commit-msg]

repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.0.290
rev: v0.0.292
hooks:
- id: ruff
args: [--fix]
Expand All @@ -27,7 +27,7 @@ repos:
- id: trailing-whitespace

- repo: https://github.com/codespell-project/codespell
rev: v2.2.5
rev: v2.2.6
hooks:
- id: codespell
stages: [commit, commit-msg]
Expand All @@ -49,7 +49,7 @@ repos:
- svelte

- repo: https://github.com/pre-commit/mirrors-eslint
rev: v8.49.0
rev: v8.50.0
hooks:
- id: eslint
types: [file]
Expand Down
2 changes: 1 addition & 1 deletion chgnet/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,7 @@ def __init__(

self.keys = []
self.keys = [
(mp_id, graph_id) for mp_id, dic in self.data.items() for graph_id in dic
(mp_id, graph_id) for mp_id, dct in self.data.items() for graph_id in dct
]
random.shuffle(self.keys)
print(f"{len(self.data)} mp_ids, {len(self)} structures imported")
Expand Down
3 changes: 0 additions & 3 deletions chgnet/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,9 +177,6 @@ def __init__(
self.device = use_device
elif torch.cuda.is_available():
self.device = "cuda"
# mps is disabled until stable version of torch for mps is released
# elif torch.backends.mps.is_available():
# self.device = "mps"
else:
self.device = "cpu"
if self.device == "cuda":
Expand Down
202 changes: 98 additions & 104 deletions examples/crystaltoolkit_relax_viewer.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -27,22 +27,7 @@
"execution_count": null,
"id": "156e8031",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Cloning into 'chgnet'...\n",
"remote: Enumerating objects: 50, done.\u001b[K\n",
"remote: Counting objects: 100% (50/50), done.\u001b[K\n",
"remote: Compressing objects: 100% (47/47), done.\u001b[K\n",
"remote: Total 50 (delta 1), reused 17 (delta 0), pack-reused 0\u001b[K\n",
"Receiving objects: 100% (50/50), 4.25 MiB | 2.70 MiB/s, done.\n",
"Resolving deltas: 100% (1/1), done.\n",
"zsh:1: no matches found: ./chgnet[crystal-toolkit]\n"
]
}
],
"outputs": [],
"source": [
"try:\n",
" import chgnet # noqa: F401\n",
Expand Down Expand Up @@ -123,76 +108,102 @@
"output_type": "stream",
"text": [
"CHGNet initialized with 400,438 parameters\n",
"CHGNet will run on cpu\n",
"CHGNet will run on mps\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/janosh/dev/chgnet/chgnet/model/composition_model.py:177: UserWarning: MPS: no support for int64 min/max ops, casting it to int32 (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/native/mps/operations/ReduceOps.mm:1271.)\n",
" composition_fea = torch.bincount(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" Step Time Energy fmax\n",
"*Force-consistent energies used in optimization.\n",
"FIRE: 0 14:01:10 -51.912251* 27.2278\n",
"FIRE: 1 14:01:10 -54.259518* 12.3964\n",
"FIRE: 2 14:01:10 -54.778671* 8.5672\n",
"FIRE: 3 14:01:11 -55.339821* 5.5388\n",
"FIRE: 4 14:01:11 -55.653206* 7.1592\n",
"FIRE: 5 14:01:11 -56.225849* 6.6752\n",
"FIRE: 6 14:01:11 -56.975388* 4.2375\n",
"FIRE: 7 14:01:11 -57.431259* 4.4837\n",
"FIRE: 8 14:01:11 -57.696171* 5.3055\n",
"FIRE: 9 14:01:11 -57.933193* 3.3038\n",
"FIRE: 10 14:01:11 -57.887894* 6.1535\n",
"FIRE: 11 14:01:11 -57.981998* 4.7339\n",
"FIRE: 12 14:01:12 -58.107471* 3.2390\n",
"FIRE: 13 14:01:12 -58.196518* 2.3609\n",
"FIRE: 14 14:01:12 -58.237015* 2.6211\n",
"FIRE: 15 14:01:12 -58.271477* 3.3198\n",
"FIRE: 16 14:01:12 -58.323418* 3.1086\n",
"FIRE: 17 14:01:12 -58.385509* 2.3118\n",
"FIRE: 18 14:01:12 -58.443253* 1.6311\n",
"FIRE: 19 14:01:12 -58.478363* 2.3351\n",
"FIRE: 20 14:01:12 -58.506485* 2.9915\n",
"FIRE: 21 14:01:13 -58.553890* 2.3584\n",
"FIRE: 22 14:01:13 -58.591610* 1.3229\n",
"FIRE: 23 14:01:13 -58.600597* 2.5641\n",
"FIRE: 24 14:01:13 -58.612297* 2.2842\n",
"FIRE: 25 14:01:13 -58.631424* 1.7666\n",
"FIRE: 26 14:01:13 -58.651585* 1.1141\n",
"FIRE: 27 14:01:14 -58.667530* 1.0229\n",
"FIRE: 28 14:01:14 -58.678112* 1.0305\n",
"FIRE: 29 14:01:14 -58.686382* 1.3254\n",
"FIRE: 30 14:01:14 -58.695862* 1.3551\n",
"FIRE: 31 14:01:14 -58.708542* 1.1167\n",
"FIRE: 32 14:01:14 -58.721897* 0.8152\n",
"FIRE: 33 14:01:14 -58.732655* 0.9976\n",
"FIRE: 34 14:01:14 -58.743347* 1.2707\n",
"FIRE: 35 14:01:14 -58.759544* 1.0066\n",
"FIRE: 36 14:01:15 -58.777481* 0.5213\n",
"FIRE: 37 14:01:15 -58.788338* 1.0805\n",
"FIRE: 38 14:01:15 -58.805302* 1.2441\n",
"FIRE: 39 14:01:15 -58.834023* 0.4629\n",
"FIRE: 40 14:01:15 -58.856319* 1.0140\n",
"FIRE: 41 14:01:15 -58.881325* 0.5240\n",
"FIRE: 42 14:01:15 -58.898716* 1.1566\n",
"FIRE: 43 14:01:15 -58.921391* 0.4211\n",
"FIRE: 44 14:01:15 -58.932678* 0.5672\n",
"FIRE: 45 14:01:15 -58.941673* 0.8468\n",
"FIRE: 46 14:01:16 -58.946671* 0.4729\n",
"FIRE: 47 14:01:16 -58.948338* 0.5703\n",
"FIRE: 48 14:01:16 -58.948872* 0.4885\n",
"FIRE: 49 14:01:16 -58.949738* 0.3384\n",
"FIRE: 50 14:01:16 -58.950603* 0.2155\n",
"FIRE: 51 14:01:16 -58.951290* 0.2596\n",
"FIRE: 52 14:01:16 -58.951885* 0.2818\n",
"FIRE: 53 14:01:16 -58.952572* 0.2948\n",
"FIRE: 54 14:01:17 -58.953487* 0.2832\n",
"FIRE: 55 14:01:17 -58.954651* 0.1820\n",
"FIRE: 56 14:01:17 -58.955776* 0.1377\n",
"FIRE: 57 14:01:17 -58.956646* 0.1858\n",
"FIRE: 58 14:01:17 -58.957542* 0.2483\n",
"FIRE: 59 14:01:17 -58.958771* 0.1507\n",
"FIRE: 60 14:01:17 -58.959930* 0.1098\n",
"FIRE: 61 14:01:17 -58.960972* 0.2491\n",
"FIRE: 62 14:01:17 -58.962578* 0.1265\n",
"FIRE: 63 14:01:17 -58.964127* 0.1622\n",
"FIRE: 64 14:01:18 -58.965885* 0.1447\n",
"FIRE: 65 14:01:18 -58.967422* 0.2064\n",
"FIRE: 66 14:01:18 -58.968880* 0.0730\n"
"FIRE: 0 17:10:46 -39.349243* 93.7997\n",
"FIRE: 1 17:10:46 -53.316616* 10.1811\n",
"FIRE: 2 17:10:46 -53.377773* 15.6013\n",
"FIRE: 3 17:10:47 -54.071163* 11.7506\n",
"FIRE: 4 17:10:47 -54.818359* 5.3271\n",
"FIRE: 5 17:10:48 -55.177044* 7.9573\n",
"FIRE: 6 17:10:48 -55.661133* 9.9137\n",
"FIRE: 7 17:10:49 -56.486736* 6.3212\n",
"FIRE: 8 17:10:49 -57.129395* 4.4612\n",
"FIRE: 9 17:10:50 -57.536762* 6.1146\n",
"FIRE: 10 17:10:50 -57.886269* 2.9151\n",
"FIRE: 11 17:10:51 -57.534672* 9.1977\n",
"FIRE: 12 17:10:52 -57.731918* 6.6787\n",
"FIRE: 13 17:10:52 -57.953892* 3.5996\n",
"FIRE: 14 17:10:53 -58.058907* 2.4661\n",
"FIRE: 15 17:10:53 -58.089428* 4.2591\n",
"FIRE: 16 17:10:53 -58.099186* 4.0957\n",
"FIRE: 17 17:10:54 -58.117531* 3.7736\n",
"FIRE: 18 17:10:54 -58.142269* 3.3059\n",
"FIRE: 19 17:10:54 -58.170528* 2.7055\n",
"FIRE: 20 17:10:55 -58.199223* 1.9944\n",
"FIRE: 21 17:10:55 -58.225574* 1.6872\n",
"FIRE: 22 17:10:56 -58.247837* 1.6420\n",
"FIRE: 23 17:10:56 -58.267609* 1.5754\n",
"FIRE: 24 17:10:57 -58.286213* 1.5527\n",
"FIRE: 25 17:10:58 -58.307659* 1.9544\n",
"FIRE: 26 17:10:58 -58.336651* 2.1167\n",
"FIRE: 27 17:10:58 -58.374733* 1.7955\n",
"FIRE: 28 17:10:59 -58.417065* 1.2746\n",
"FIRE: 29 17:10:59 -58.454224* 1.0842\n",
"FIRE: 30 17:11:00 -58.482494* 1.4909\n",
"FIRE: 31 17:11:00 -58.511620* 1.7936\n",
"FIRE: 32 17:11:00 -58.551266* 1.3864\n",
"FIRE: 33 17:11:00 -58.593399* 0.8655\n",
"FIRE: 34 17:11:01 -58.626717* 1.1029\n",
"FIRE: 35 17:11:01 -58.667667* 1.1856\n",
"FIRE: 36 17:11:02 -58.714378* 0.7611\n",
"FIRE: 37 17:11:02 -58.751740* 1.4115\n",
"FIRE: 38 17:11:02 -58.798595* 0.6277\n",
"FIRE: 39 17:11:03 -58.825634* 1.6683\n",
"FIRE: 40 17:11:04 -58.860550* 0.6463\n",
"FIRE: 41 17:11:04 -58.879448* 1.6940\n",
"FIRE: 42 17:11:05 -58.889172* 1.0395\n",
"FIRE: 43 17:11:05 -58.897785* 0.5511\n",
"FIRE: 44 17:11:06 -58.900936* 0.9203\n",
"FIRE: 45 17:11:06 -58.902317* 0.8079\n",
"FIRE: 46 17:11:06 -58.904602* 0.5999\n",
"FIRE: 47 17:11:06 -58.907150* 0.4713\n",
"FIRE: 48 17:11:07 -58.909378* 0.4035\n",
"FIRE: 49 17:11:07 -58.911129* 0.3782\n",
"FIRE: 50 17:11:07 -58.912685* 0.4900\n",
"FIRE: 51 17:11:08 -58.914524* 0.5925\n",
"FIRE: 52 17:11:08 -58.917168* 0.5787\n",
"FIRE: 53 17:11:09 -58.920650* 0.4262\n",
"FIRE: 54 17:11:09 -58.924351* 0.2559\n",
"FIRE: 55 17:11:10 -58.927425* 0.2542\n",
"FIRE: 56 17:11:10 -58.930111* 0.4618\n",
"FIRE: 57 17:11:10 -58.933582* 0.4244\n",
"FIRE: 58 17:11:11 -58.937565* 0.2129\n",
"FIRE: 59 17:11:11 -58.940548* 0.3162\n",
"FIRE: 60 17:11:11 -58.943970* 0.3788\n",
"FIRE: 61 17:11:11 -58.948582* 0.1709\n",
"FIRE: 62 17:11:12 -58.952435* 0.4052\n",
"FIRE: 63 17:11:12 -58.957687* 0.1760\n",
"FIRE: 64 17:11:13 -58.961376* 0.4583\n",
"FIRE: 65 17:11:13 -58.965767* 0.1113\n",
"FIRE: 66 17:11:13 -58.967800* 0.4162\n",
"FIRE: 67 17:11:13 -58.969425* 0.5236\n",
"FIRE: 68 17:11:14 -58.970615* 0.2232\n",
"FIRE: 69 17:11:14 -58.970974* 0.2027\n",
"FIRE: 70 17:11:14 -58.971069* 0.1704\n",
"FIRE: 71 17:11:14 -58.971218* 0.1290\n",
"FIRE: 72 17:11:14 -58.971375* 0.1228\n",
"FIRE: 73 17:11:15 -58.971523* 0.1153\n",
"FIRE: 74 17:11:15 -58.971649* 0.1073\n",
"FIRE: 75 17:11:15 -58.971767* 0.1409\n",
"FIRE: 76 17:11:15 -58.971931* 0.1412\n",
"FIRE: 77 17:11:15 -58.972164* 0.1046\n",
"FIRE: 78 17:11:15 -58.972416* 0.0777\n"
]
}
],
Expand Down Expand Up @@ -248,30 +259,13 @@
"id": "c9f16422",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/janosh/dev/crystaltoolkit/crystal_toolkit/components/diffraction_tem.py:18: UserWarning: The TEMDiffractionComponent requires the py4DSTEM package.\n",
" warn(\"The TEMDiffractionComponent requires the py4DSTEM package.\")\n",
"/Users/janosh/dev/crystaltoolkit/crystal_toolkit/components/localenv.py:50: UserWarning: Using dscribe SOAP and REMatchKernel requires the dscribe package which was made optional since it in turn requires numba and numba was a common source of installation issues.\n",
" warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"No module named 'phonopy'\n"
]
},
{
"data": {
"text/html": [
"\n",
" <iframe\n",
" width=\"100%\"\n",
" height=\"800\"\n",
" height=\"650\"\n",
" src=\"http://127.0.0.1:8050/\"\n",
" frameborder=\"0\"\n",
" allowfullscreen\n",
Expand All @@ -280,7 +274,7 @@
" "
],
"text/plain": [
"<IPython.lib.display.IFrame at 0x14d031410>"
"<IPython.lib.display.IFrame at 0x3d5938e10>"
]
},
"metadata": {},
Expand Down Expand Up @@ -392,7 +386,7 @@
" return structure, fig\n",
"\n",
"\n",
"app.run(mode=\"inline\", height=800, use_reloader=False)"
"app.run(height=800, use_reloader=False)"
]
}
],
Expand Down
8 changes: 4 additions & 4 deletions examples/fine_tuning.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
"CHGNet is interfaced to [Pymatgen](https://pymatgen.org/), the training samples (normally coming from different DFTs like VASP),\n",
"need to be converted to [pymatgen.core.structure](https://pymatgen.org/pymatgen.core.html#module-pymatgen.core.structure).\n",
"\n",
"To convert VASP calculation to pymatgen structures and CHGNet labels, you can use the following [code](https://github.com/CederGroupHub/chgnet/blob/main/chgnet/utils/vasp_utils.py):"
"To convert VASP calculation to pymatgen structures and CHGNet labels, you can use the following [code](https://github.com/CederGroupHub/chgnet/blob/main/chgnet/utils/vasp_utils.py):\n"
]
},
{
Expand All @@ -79,7 +79,7 @@
"\n",
"# ./my_vasp_calc_dir contains vasprun.xml OSZICAR etc.\n",
"dataset_dict = parse_vasp_dir(file_root=\"./my_vasp_calc_dir\")\n",
"print(dataset_dict.keys())"
"print(list(dataset_dict))"
]
},
{
Expand All @@ -93,7 +93,7 @@
"\n",
"For super-large training dataset, like MPtrj dataset, we recommend [converting them to CHGNet graphs](https://github.com/CederGroupHub/chgnet/blob/main/examples/make_graphs.py). This will save significant memory and graph computing time.\n",
"\n",
"Below are the example codes to save the structures."
"Below are the example codes to save the structures.\n"
]
},
{
Expand Down Expand Up @@ -155,7 +155,7 @@
"id": "e1611921",
"metadata": {},
"source": [
"## 1. Prepare Training Data"
"## 1. Prepare Training Data\n"
]
},
{
Expand Down