Skip to content

Commit

Permalink
Bump torch2.4.1 and pyg (#845)
Browse files Browse the repository at this point in the history
* bump torch and pyg

* fix exported_prog call according to torch error message

* add .module()

* missed one last one

* try with torch 2.4.1

* update yml configs
  • Loading branch information
misko authored Sep 16, 2024
1 parent 83ab799 commit 6696a7e
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 19 deletions.
15 changes: 9 additions & 6 deletions packages/env.cpu.yml
Original file line number Diff line number Diff line change
@@ -1,22 +1,25 @@
channels:
- pytorch
- pyg
- conda-forge
- defaults
dependencies:
- cpuonly
- pytorch>=2
- pyg
- pytorch-scatter
- pytorch-sparse
- pytorch-cluster
- pytorch>=2.4
- ase
- e3nn>=0.5
- numpy >=1.25.0,<2.0.0
- pymatgen>=2023.10.3
- numba
- orjson
- pip
- pip:
- --find-links https://data.pyg.org/whl/torch-2.4.0+cu121.html
- torch_cluster==1.6.3+pt24cpu
- torch_geometric==2.5.3
- pyg-lib==0.4.0+pt24cpu
- torch_scatter==2.1.2+pt24cpu
- torch_sparse==0.6.18+pt24cpu
- torch_spline_conv==1.2.2+pt24cpu
- pyyaml
- tqdm
- python-lmdb
Expand Down
17 changes: 10 additions & 7 deletions packages/env.gpu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,25 @@ channels:
- pytorch
- nvidia
- conda-forge
- pyg
- defaults
dependencies:
- pytorch-cuda=11.8
- pytorch>=2
- pytorch-scatter
- pytorch-sparse
- pytorch-cluster
- pyg
- pytorch-cuda=12.1
- pytorch>=2.4
- ase
- e3nn>=0.5
- numpy >=1.25.0,<2.0.0
- pymatgen>=2023.10.3
- numba
- orjson
- pip
- pip:
- --find-links https://data.pyg.org/whl/torch-2.4.0+cu121.html
- torch_cluster==1.6.3+pt24cu121
- torch_geometric==2.5.3
- pyg-lib==0.4.0+pt24cu121
- torch_scatter==2.1.2+pt24cu121
- torch_sparse==0.6.18+pt24cu121
- torch_spline_conv==1.2.2+pt24cu121
- pyyaml
- tqdm
- python-lmdb
Expand Down
4 changes: 2 additions & 2 deletions packages/requirements-optional.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
torch_geometric==2.3.0
-f https://data.pyg.org/whl/torch-2.2.0+cpu.html
torch_geometric==2.5.3
-f https://data.pyg.org/whl/torch-2.4.0+cpu.html
torch_scatter==2.1.2
torch_sparse==0.6.18
torch_cluster==1.6.3
2 changes: 1 addition & 1 deletion packages/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
torch==2.2.0
torch==2.4.1
numpy==1.23.5
ase==3.23.0
6 changes: 3 additions & 3 deletions tests/core/models/test_escn_compiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def test_escn_message_block_exports_and_compiles(self, tol=1e-5) -> None:
compiled_output = compiled_model(*args[0])

exported_prog = export(message_block, args=args[0])
exported_output = exported_prog(*args[0])
exported_output = exported_prog.module()(*args[0])

regular_out = message_block(*args[0])
assert torch.allclose(compiled_output, regular_out, atol=tol)
Expand Down Expand Up @@ -302,7 +302,7 @@ def test_escn_layer_block_exports_and_compiles(self, tol=1e-5) -> None:
}
exported_prog = export(layer_block, args=run_args[0], dynamic_shapes=dynamic_shapes1)
for run_arg in run_args:
exported_output = exported_prog(*run_arg)
exported_output = exported_prog.module()(*run_arg)
compiled_model = torch.compile(layer_block, dynamic=True)
compiled_output = compiled_model(*run_arg)
regular_out = layer_block(*run_arg)
Expand Down Expand Up @@ -343,7 +343,7 @@ def test_full_escn_exports(self):
# print(explained_output)
# TODO: add dynamic shapes
exported_prog = export(exportable_model, args=(export_data,))
export_output = exported_prog(export_data)
export_output = exported_prog.module()(export_data)
expected_output = escn_model(regular_data)
assert torch.allclose(export_output["energy"], expected_output["energy"])
assert torch.allclose(export_output["forces"].mean(0), expected_output["forces"].mean(0))

0 comments on commit 6696a7e

Please sign in to comment.