From 1d20f5042b570895650b897ff5133dc55ac4bcba Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Wed, 4 Oct 2023 15:49:29 -0700 Subject: [PATCH 1/5] run cibuildwheel on Apple silicon --- .github/workflows/test.yml | 4 ++-- .pre-commit-config.yaml | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index f4657f93..b1df8292 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -78,7 +78,7 @@ jobs: if: github.event_name == 'release' || (github.event_name == 'workflow_dispatch' && inputs.task == 'release') strategy: matrix: - os: [ubuntu-latest, macos-latest, windows-latest] + os: [ubuntu-latest, macos-latest, macos-latest-xlarge, windows-latest] python-version: ["39", "310", "311"] runs-on: ${{ matrix.os }} steps: @@ -86,7 +86,7 @@ jobs: uses: actions/checkout@v4 - name: Build wheels - uses: pypa/cibuildwheel@v2.15.0 + uses: pypa/cibuildwheel@v2.16.2 env: CIBW_BUILD: cp${{ matrix.python-version }}-* diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d3ccb12d..32d5d0c0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,7 +4,7 @@ default_install_hook_types: [pre-commit, commit-msg] repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.0.290 + rev: v0.0.292 hooks: - id: ruff args: [--fix] @@ -27,7 +27,7 @@ repos: - id: trailing-whitespace - repo: https://github.com/codespell-project/codespell - rev: v2.2.5 + rev: v2.2.6 hooks: - id: codespell stages: [commit, commit-msg] @@ -49,7 +49,7 @@ repos: - svelte - repo: https://github.com/pre-commit/mirrors-eslint - rev: v8.49.0 + rev: v8.50.0 hooks: - id: eslint types: [file] From ed9181b96c619a9756f0d437ee273c6c2a7ceab2 Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Wed, 4 Oct 2023 18:32:53 -0700 Subject: [PATCH 2/5] fix typo --- chgnet/data/dataset.py | 2 +- chgnet/trainer/trainer.py | 3 - examples/crystaltoolkit_relax_viewer.ipynb | 202 ++++++++++----------- examples/fine_tuning.ipynb | 8 +- 4 files changed, 103 insertions(+), 112 deletions(-) diff --git a/chgnet/data/dataset.py b/chgnet/data/dataset.py index c58036e0..1d8ba35b 100644 --- a/chgnet/data/dataset.py +++ b/chgnet/data/dataset.py @@ -520,7 +520,7 @@ def __init__( self.keys = [] self.keys = [ - (mp_id, graph_id) for mp_id, dic in self.data.items() for graph_id in dic + (mp_id, graph_id) for mp_id, dct in self.data.items() for graph_id in dct ] random.shuffle(self.keys) print(f"{len(self.data)} mp_ids, {len(self)} structures imported") diff --git a/chgnet/trainer/trainer.py b/chgnet/trainer/trainer.py index 54099c33..2b46b6af 100644 --- a/chgnet/trainer/trainer.py +++ b/chgnet/trainer/trainer.py @@ -177,9 +177,6 @@ def __init__( self.device = use_device elif torch.cuda.is_available(): self.device = "cuda" - # mps is disabled until stable version of torch for mps is released - # elif torch.backends.mps.is_available(): - # self.device = "mps" else: self.device = "cpu" if self.device == "cuda": diff --git a/examples/crystaltoolkit_relax_viewer.ipynb b/examples/crystaltoolkit_relax_viewer.ipynb index 0569eb9e..fc1cd187 100644 --- a/examples/crystaltoolkit_relax_viewer.ipynb +++ b/examples/crystaltoolkit_relax_viewer.ipynb @@ -27,22 +27,7 @@ "execution_count": null, "id": "156e8031", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Cloning into 'chgnet'...\n", - "remote: Enumerating objects: 50, done.\u001b[K\n", - "remote: Counting objects: 100% (50/50), done.\u001b[K\n", - "remote: 
Compressing objects: 100% (47/47), done.\u001b[K\n", - "remote: Total 50 (delta 1), reused 17 (delta 0), pack-reused 0\u001b[K\n", - "Receiving objects: 100% (50/50), 4.25 MiB | 2.70 MiB/s, done.\n", - "Resolving deltas: 100% (1/1), done.\n", - "zsh:1: no matches found: ./chgnet[crystal-toolkit]\n" - ] - } - ], + "outputs": [], "source": [ "try:\n", " import chgnet # noqa: F401\n", @@ -123,76 +108,102 @@ "output_type": "stream", "text": [ "CHGNet initialized with 400,438 parameters\n", - "CHGNet will run on cpu\n", + "CHGNet will run on mps\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/janosh/dev/chgnet/chgnet/model/composition_model.py:177: UserWarning: MPS: no support for int64 min/max ops, casting it to int32 (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/native/mps/operations/ReduceOps.mm:1271.)\n", + " composition_fea = torch.bincount(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ " Step Time Energy fmax\n", "*Force-consistent energies used in optimization.\n", - "FIRE: 0 14:01:10 -51.912251* 27.2278\n", - "FIRE: 1 14:01:10 -54.259518* 12.3964\n", - "FIRE: 2 14:01:10 -54.778671* 8.5672\n", - "FIRE: 3 14:01:11 -55.339821* 5.5388\n", - "FIRE: 4 14:01:11 -55.653206* 7.1592\n", - "FIRE: 5 14:01:11 -56.225849* 6.6752\n", - "FIRE: 6 14:01:11 -56.975388* 4.2375\n", - "FIRE: 7 14:01:11 -57.431259* 4.4837\n", - "FIRE: 8 14:01:11 -57.696171* 5.3055\n", - "FIRE: 9 14:01:11 -57.933193* 3.3038\n", - "FIRE: 10 14:01:11 -57.887894* 6.1535\n", - "FIRE: 11 14:01:11 -57.981998* 4.7339\n", - "FIRE: 12 14:01:12 -58.107471* 3.2390\n", - "FIRE: 13 14:01:12 -58.196518* 2.3609\n", - "FIRE: 14 14:01:12 -58.237015* 2.6211\n", - "FIRE: 15 14:01:12 -58.271477* 3.3198\n", - "FIRE: 16 14:01:12 -58.323418* 3.1086\n", - "FIRE: 17 14:01:12 -58.385509* 2.3118\n", - "FIRE: 18 14:01:12 -58.443253* 1.6311\n", - "FIRE: 19 14:01:12 -58.478363* 2.3351\n", - "FIRE: 20 14:01:12 -58.506485* 2.9915\n", - "FIRE: 21 14:01:13 -58.553890* 2.3584\n", - "FIRE: 22 14:01:13 -58.591610* 1.3229\n", - "FIRE: 23 14:01:13 -58.600597* 2.5641\n", - "FIRE: 24 14:01:13 -58.612297* 2.2842\n", - "FIRE: 25 14:01:13 -58.631424* 1.7666\n", - "FIRE: 26 14:01:13 -58.651585* 1.1141\n", - "FIRE: 27 14:01:14 -58.667530* 1.0229\n", - "FIRE: 28 14:01:14 -58.678112* 1.0305\n", - "FIRE: 29 14:01:14 -58.686382* 1.3254\n", - "FIRE: 30 14:01:14 -58.695862* 1.3551\n", - "FIRE: 31 14:01:14 -58.708542* 1.1167\n", - "FIRE: 32 14:01:14 -58.721897* 0.8152\n", - "FIRE: 33 14:01:14 -58.732655* 0.9976\n", - "FIRE: 34 14:01:14 -58.743347* 1.2707\n", - "FIRE: 35 14:01:14 -58.759544* 1.0066\n", - "FIRE: 36 14:01:15 -58.777481* 0.5213\n", - "FIRE: 37 14:01:15 -58.788338* 1.0805\n", - "FIRE: 38 14:01:15 -58.805302* 1.2441\n", - "FIRE: 39 14:01:15 -58.834023* 0.4629\n", - "FIRE: 40 14:01:15 -58.856319* 1.0140\n", - "FIRE: 41 14:01:15 -58.881325* 0.5240\n", - "FIRE: 42 14:01:15 -58.898716* 1.1566\n", - "FIRE: 43 14:01:15 -58.921391* 0.4211\n", - "FIRE: 44 14:01:15 -58.932678* 0.5672\n", - "FIRE: 45 14:01:15 -58.941673* 0.8468\n", - "FIRE: 46 14:01:16 -58.946671* 0.4729\n", - "FIRE: 47 14:01:16 -58.948338* 0.5703\n", - "FIRE: 48 14:01:16 -58.948872* 0.4885\n", - "FIRE: 49 14:01:16 -58.949738* 0.3384\n", - "FIRE: 50 14:01:16 -58.950603* 0.2155\n", - "FIRE: 51 14:01:16 -58.951290* 0.2596\n", - "FIRE: 52 14:01:16 -58.951885* 0.2818\n", - "FIRE: 53 14:01:16 -58.952572* 0.2948\n", - "FIRE: 54 14:01:17 -58.953487* 0.2832\n", - "FIRE: 55 14:01:17 -58.954651* 0.1820\n", - "FIRE: 56 
14:01:17 -58.955776* 0.1377\n", - "FIRE: 57 14:01:17 -58.956646* 0.1858\n", - "FIRE: 58 14:01:17 -58.957542* 0.2483\n", - "FIRE: 59 14:01:17 -58.958771* 0.1507\n", - "FIRE: 60 14:01:17 -58.959930* 0.1098\n", - "FIRE: 61 14:01:17 -58.960972* 0.2491\n", - "FIRE: 62 14:01:17 -58.962578* 0.1265\n", - "FIRE: 63 14:01:17 -58.964127* 0.1622\n", - "FIRE: 64 14:01:18 -58.965885* 0.1447\n", - "FIRE: 65 14:01:18 -58.967422* 0.2064\n", - "FIRE: 66 14:01:18 -58.968880* 0.0730\n" + "FIRE: 0 17:10:46 -39.349243* 93.7997\n", + "FIRE: 1 17:10:46 -53.316616* 10.1811\n", + "FIRE: 2 17:10:46 -53.377773* 15.6013\n", + "FIRE: 3 17:10:47 -54.071163* 11.7506\n", + "FIRE: 4 17:10:47 -54.818359* 5.3271\n", + "FIRE: 5 17:10:48 -55.177044* 7.9573\n", + "FIRE: 6 17:10:48 -55.661133* 9.9137\n", + "FIRE: 7 17:10:49 -56.486736* 6.3212\n", + "FIRE: 8 17:10:49 -57.129395* 4.4612\n", + "FIRE: 9 17:10:50 -57.536762* 6.1146\n", + "FIRE: 10 17:10:50 -57.886269* 2.9151\n", + "FIRE: 11 17:10:51 -57.534672* 9.1977\n", + "FIRE: 12 17:10:52 -57.731918* 6.6787\n", + "FIRE: 13 17:10:52 -57.953892* 3.5996\n", + "FIRE: 14 17:10:53 -58.058907* 2.4661\n", + "FIRE: 15 17:10:53 -58.089428* 4.2591\n", + "FIRE: 16 17:10:53 -58.099186* 4.0957\n", + "FIRE: 17 17:10:54 -58.117531* 3.7736\n", + "FIRE: 18 17:10:54 -58.142269* 3.3059\n", + "FIRE: 19 17:10:54 -58.170528* 2.7055\n", + "FIRE: 20 17:10:55 -58.199223* 1.9944\n", + "FIRE: 21 17:10:55 -58.225574* 1.6872\n", + "FIRE: 22 17:10:56 -58.247837* 1.6420\n", + "FIRE: 23 17:10:56 -58.267609* 1.5754\n", + "FIRE: 24 17:10:57 -58.286213* 1.5527\n", + "FIRE: 25 17:10:58 -58.307659* 1.9544\n", + "FIRE: 26 17:10:58 -58.336651* 2.1167\n", + "FIRE: 27 17:10:58 -58.374733* 1.7955\n", + "FIRE: 28 17:10:59 -58.417065* 1.2746\n", + "FIRE: 29 17:10:59 -58.454224* 1.0842\n", + "FIRE: 30 17:11:00 -58.482494* 1.4909\n", + "FIRE: 31 17:11:00 -58.511620* 1.7936\n", + "FIRE: 32 17:11:00 -58.551266* 1.3864\n", + "FIRE: 33 17:11:00 -58.593399* 0.8655\n", + "FIRE: 34 17:11:01 -58.626717* 1.1029\n", + "FIRE: 35 17:11:01 -58.667667* 1.1856\n", + "FIRE: 36 17:11:02 -58.714378* 0.7611\n", + "FIRE: 37 17:11:02 -58.751740* 1.4115\n", + "FIRE: 38 17:11:02 -58.798595* 0.6277\n", + "FIRE: 39 17:11:03 -58.825634* 1.6683\n", + "FIRE: 40 17:11:04 -58.860550* 0.6463\n", + "FIRE: 41 17:11:04 -58.879448* 1.6940\n", + "FIRE: 42 17:11:05 -58.889172* 1.0395\n", + "FIRE: 43 17:11:05 -58.897785* 0.5511\n", + "FIRE: 44 17:11:06 -58.900936* 0.9203\n", + "FIRE: 45 17:11:06 -58.902317* 0.8079\n", + "FIRE: 46 17:11:06 -58.904602* 0.5999\n", + "FIRE: 47 17:11:06 -58.907150* 0.4713\n", + "FIRE: 48 17:11:07 -58.909378* 0.4035\n", + "FIRE: 49 17:11:07 -58.911129* 0.3782\n", + "FIRE: 50 17:11:07 -58.912685* 0.4900\n", + "FIRE: 51 17:11:08 -58.914524* 0.5925\n", + "FIRE: 52 17:11:08 -58.917168* 0.5787\n", + "FIRE: 53 17:11:09 -58.920650* 0.4262\n", + "FIRE: 54 17:11:09 -58.924351* 0.2559\n", + "FIRE: 55 17:11:10 -58.927425* 0.2542\n", + "FIRE: 56 17:11:10 -58.930111* 0.4618\n", + "FIRE: 57 17:11:10 -58.933582* 0.4244\n", + "FIRE: 58 17:11:11 -58.937565* 0.2129\n", + "FIRE: 59 17:11:11 -58.940548* 0.3162\n", + "FIRE: 60 17:11:11 -58.943970* 0.3788\n", + "FIRE: 61 17:11:11 -58.948582* 0.1709\n", + "FIRE: 62 17:11:12 -58.952435* 0.4052\n", + "FIRE: 63 17:11:12 -58.957687* 0.1760\n", + "FIRE: 64 17:11:13 -58.961376* 0.4583\n", + "FIRE: 65 17:11:13 -58.965767* 0.1113\n", + "FIRE: 66 17:11:13 -58.967800* 0.4162\n", + "FIRE: 67 17:11:13 -58.969425* 0.5236\n", + "FIRE: 68 17:11:14 -58.970615* 0.2232\n", + "FIRE: 69 17:11:14 -58.970974* 0.2027\n", + "FIRE: 
70 17:11:14 -58.971069* 0.1704\n", + "FIRE: 71 17:11:14 -58.971218* 0.1290\n", + "FIRE: 72 17:11:14 -58.971375* 0.1228\n", + "FIRE: 73 17:11:15 -58.971523* 0.1153\n", + "FIRE: 74 17:11:15 -58.971649* 0.1073\n", + "FIRE: 75 17:11:15 -58.971767* 0.1409\n", + "FIRE: 76 17:11:15 -58.971931* 0.1412\n", + "FIRE: 77 17:11:15 -58.972164* 0.1046\n", + "FIRE: 78 17:11:15 -58.972416* 0.0777\n" ] } ], @@ -248,30 +259,13 @@ "id": "c9f16422", "metadata": {}, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/janosh/dev/crystaltoolkit/crystal_toolkit/components/diffraction_tem.py:18: UserWarning: The TEMDiffractionComponent requires the py4DSTEM package.\n", - " warn(\"The TEMDiffractionComponent requires the py4DSTEM package.\")\n", - "/Users/janosh/dev/crystaltoolkit/crystal_toolkit/components/localenv.py:50: UserWarning: Using dscribe SOAP and REMatchKernel requires the dscribe package which was made optional since it in turn requires numba and numba was a common source of installation issues.\n", - " warn(\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "No module named 'phonopy'\n" - ] - }, { "data": { "text/html": [ "\n", " " + "" ] }, "metadata": {}, @@ -392,7 +386,7 @@ " return structure, fig\n", "\n", "\n", - "app.run(mode=\"inline\", height=800, use_reloader=False)" + "app.run(height=800, use_reloader=False)" ] } ], diff --git a/examples/fine_tuning.ipynb b/examples/fine_tuning.ipynb index 632e97a1..a65969b0 100644 --- a/examples/fine_tuning.ipynb +++ b/examples/fine_tuning.ipynb @@ -61,7 +61,7 @@ "CHGNet is interfaced to [Pymatgen](https://pymatgen.org/), the training samples (normally coming from different DFTs like VASP),\n", "need to be converted to [pymatgen.core.structure](https://pymatgen.org/pymatgen.core.html#module-pymatgen.core.structure).\n", "\n", - "To convert VASP calculation to pymatgen structures and CHGNet labels, you can use the following [code](https://github.com/CederGroupHub/chgnet/blob/main/chgnet/utils/vasp_utils.py):" + "To convert VASP calculation to pymatgen structures and CHGNet labels, you can use the following [code](https://github.com/CederGroupHub/chgnet/blob/main/chgnet/utils/vasp_utils.py):\n" ] }, { @@ -79,7 +79,7 @@ "\n", "# ./my_vasp_calc_dir contains vasprun.xml OSZICAR etc.\n", "dataset_dict = parse_vasp_dir(file_root=\"./my_vasp_calc_dir\")\n", - "print(dataset_dict.keys())" + "print(list(dataset_dict))" ] }, { @@ -93,7 +93,7 @@ "\n", "For super-large training dataset, like MPtrj dataset, we recommend [converting them to CHGNet graphs](https://github.com/CederGroupHub/chgnet/blob/main/examples/make_graphs.py). This will save significant memory and graph computing time.\n", "\n", - "Below are the example codes to save the structures." + "Below are the example codes to save the structures.\n" ] }, { @@ -155,7 +155,7 @@ "id": "e1611921", "metadata": {}, "source": [ - "## 1. Prepare Training Data" + "## 1. 
Prepare Training Data\n" ] }, { From 5468d756f42deec62b95e95ffece7b1b30795baa Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Wed, 4 Oct 2023 18:43:29 -0700 Subject: [PATCH 3/5] add option to release to TestPyPI --- .github/workflows/test.yml | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index b1df8292..7df3130f 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -11,9 +11,9 @@ on: inputs: task: type: choice - options: [tests, release] + options: [tests, release, test-release] default: tests - description: Only run tests or release a new version to PyPI after tests pass. + description: Run tests, release to PyPI, or release to TestPyPI. jobs: tests: @@ -108,8 +108,9 @@ jobs: name: artifact path: dist - - name: Publish to PyPi + - name: Publish to PyPi or TestPyPI uses: pypa/gh-action-pypi-publish@release/v1 with: skip-existing: true verbose: true + repository-url: ${{ github.event_name == 'workflow_dispatch' && github.event.inputs.task == 'test-release' && 'https://test.pypi.org/legacy/' || '' }} From 79b23ab8795d8712a469de769a2f25c19e9fb395 Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Wed, 4 Oct 2023 18:49:56 -0700 Subject: [PATCH 4/5] run CI tests on macos-latest-xlarge hopefully fixes RuntimeError: MPS backend out of memory (MPS allocated: 0 bytes, other allocations: 0 bytes, max allowed: 1.70 GB). Tried to allocate 0 bytes on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure). --- .github/workflows/test.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 7df3130f..66cedb7d 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -20,7 +20,8 @@ jobs: strategy: fail-fast: true matrix: - os: [ubuntu-latest, macos-latest, windows-latest] + # try macos-latest-xlarge to fix flaky 'RuntimeError: MPS backend out of memory' + os: [ubuntu-latest, macos-latest-xlarge, windows-latest] runs-on: ${{ matrix.os }} steps: @@ -78,6 +79,7 @@ jobs: if: github.event_name == 'release' || (github.event_name == 'workflow_dispatch' && inputs.task == 'release') strategy: matrix: + # try RuntimeError: MPS backend out of memory os: [ubuntu-latest, macos-latest, macos-latest-xlarge, windows-latest] python-version: ["39", "310", "311"] runs-on: ${{ matrix.os }} From 7aa1997be2d3365cd90494026cd74faaff2a5c78 Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Fri, 6 Oct 2023 08:23:24 -0700 Subject: [PATCH 5/5] pypa/cibuildwheel set CIBW_ARCHS_MACOS: universal2 --- .github/workflows/test.yml | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 66cedb7d..f85f3cc9 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -20,8 +20,7 @@ jobs: strategy: fail-fast: true matrix: - # try macos-latest-xlarge to fix flaky 'RuntimeError: MPS backend out of memory' - os: [ubuntu-latest, macos-latest-xlarge, windows-latest] + os: [ubuntu-latest, macos-latest, windows-latest] runs-on: ${{ matrix.os }} steps: @@ -79,8 +78,7 @@ jobs: if: github.event_name == 'release' || (github.event_name == 'workflow_dispatch' && inputs.task == 'release') strategy: matrix: - # try RuntimeError: MPS backend out of memory - os: [ubuntu-latest, macos-latest, macos-latest-xlarge, windows-latest] + os: [ubuntu-latest, macos-latest, windows-latest] 
python-version: ["39", "310", "311"] runs-on: ${{ matrix.os }} steps: @@ -91,6 +89,7 @@ jobs: uses: pypa/cibuildwheel@v2.16.2 env: CIBW_BUILD: cp${{ matrix.python-version }}-* + CIBW_ARCHS_MACOS: universal2 - name: Save artifact uses: actions/upload-artifact@v3