Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Torch compile + export escn #826

Merged
merged 26 commits into from
Sep 10, 2024
Merged

Torch compile + export escn #826

merged 26 commits into from
Sep 10, 2024

Conversation

rayg1234
Copy link
Collaborator

@rayg1234 rayg1234 commented Sep 3, 2024

Compilable and Exportable version of escn.

  • Input is dict[str, tensor] (instead of torch geometric object, requires adapter (currently using data_list_collater) and modification in the trainer to use
  • Removes SO3_Embedding + SO3_Rotation objects
  • Move graph generation to dataloader (using pymatgen), this is equivalent when using large max_neighbors but contains discrepancy when max_neighbors is small (ie: 20)
  • Removes all non-tensor inputs to module functions
  • Removes non export compatible python statements (ie: asserts - which are ok in non-strict mode in torch2.4)
  • Tested to be equivalent to escn on sample inputs

This fully compiles with 0 graph breaks.

Next PR:

  • export still need to support dynamic inputs
  • may need to retrain escn export
  • check exported code works in c++ torch executable

@rayg1234 rayg1234 marked this pull request as ready for review September 9, 2024 22:26
@rayg1234 rayg1234 added enhancement New feature or request minor Minor version release labels Sep 9, 2024
@rayg1234 rayg1234 requested review from kyonofx and lbluque September 9, 2024 22:37
Copy link

codecov bot commented Sep 9, 2024

Codecov Report

Attention: Patch coverage is 93.63817% with 32 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
src/fairchem/core/models/escn/so3_exportable.py 86.02% 26 Missing ⚠️
src/fairchem/core/models/escn/escn_exportable.py 98.01% 6 Missing ⚠️
Files with missing lines Coverage Δ
src/fairchem/core/common/test_utils.py 89.09% <100.00%> (+0.85%) ⬆️
src/fairchem/core/datasets/lmdb_dataset.py 77.35% <100.00%> (+0.43%) ⬆️
src/fairchem/core/preprocessing/atoms_to_graphs.py 89.14% <100.00%> (+0.71%) ⬆️
src/fairchem/core/models/escn/escn_exportable.py 98.01% <98.01%> (ø)
src/fairchem/core/models/escn/so3_exportable.py 86.02% <86.02%> (ø)

self.use_pbc = use_pbc
self.use_pbc_single = use_pbc_single
self.cutoff = cutoff
self.otf_graph = otf_graph
Copy link
Collaborator

@kyonofx kyonofx Sep 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

currently this model assumes otf_graph=False. If this is what's going to be in the final version I suggest raising an error if passed otf_graph=True.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ya let me remove this actually

energy = torch.zeros(len(natoms), device=node_energy.device)
energy.index_add_(0, batch_idx, node_energy.view(-1))
# Scale energy to help balance numerical precision w.r.t. forces
energy = energy * 0.001
Copy link
Collaborator

@kyonofx kyonofx Sep 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we want to remove energy = energy * 0.001 eventually but could be kept for backward compat.


# Compare predicted energies and forces (after inv-rotation).
energies = out["energy"].detach()
np.testing.assert_almost_equal(energies[0], energies[1], decimal=5)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

with reasonable graphs (no max neighbor limit or no strict max neighbor) this should pass with 6 or even 7 decimal places.

np.testing.assert_array_almost_equal(
forces[: forces.shape[0] // 2],
torch.matmul(forces[forces.shape[0] // 2 :], inv_rot),
decimal=5,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

with reasonable graphs (no max neighbor limit or no strict max neighbor) this should pass with 6 or even 7 decimal places.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i tried with energy i can get to 7 decimals but forces only 5

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

a single layer should get 7, seems like you are using 8 and that could be why.

np.testing.assert_array_almost_equal(
forces[: forces.shape[0] // 2],
torch.matmul(forces[forces.shape[0] // 2 :], inv_rot),
decimal=5,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

a single layer should get 7, seems like you are using 8 and that could be why.

@rayg1234 rayg1234 enabled auto-merge September 10, 2024 19:09
@rayg1234 rayg1234 added this pull request to the merge queue Sep 10, 2024
Merged via the queue into main with commit eddb484 Sep 10, 2024
8 checks passed
@rayg1234 rayg1234 deleted the rgao_torch_compile2 branch September 10, 2024 21:16
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request minor Minor version release
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants