try to get full edge embedding from gp
misko committed Jul 19, 2024
1 parent 6c36f2c commit 1f1cf0f
Showing 1 changed file with 19 additions and 8 deletions.
27 changes: 19 additions & 8 deletions src/fairchem/core/models/equiformer_v2/equiformer_v2_oc20.py
@@ -456,9 +456,9 @@ def _forward(self, data):
         atomic_numbers = data.atomic_numbers.long()

         (
-            edge_index,
-            edge_distance,
-            edge_distance_vec,
+            edge_index_full,
+            edge_distance_full,
+            edge_distance_vec_full,
             cell_offsets,
             _,  # cell offset distances
             neighbors,
@@ -482,9 +482,9 @@
         ) = self._init_gp_partitions(
             atomic_numbers_full,
             data_batch_full,
-            edge_index,
-            edge_distance,
-            edge_distance_vec,
+            edge_index_full,
+            edge_distance_full,
+            edge_distance_vec_full,
         )
         ###############################################################
         # Entering Graph Parallel Region
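
For context on this change: the full-graph edge tensors are now kept under the *_full names and passed into _init_gp_partitions, which produces the rank-local edge_index, edge_distance, and edge_distance_vec used inside the graph-parallel region. A minimal sketch of that partitioning idea, assuming a hypothetical partition_edges helper and a simple modulo ownership rule (not the actual _init_gp_partitions implementation):

import torch

def partition_edges(edge_index_full, edge_distance_full, edge_distance_vec_full,
                    rank, world_size):
    # Keep only the edges whose target atom this graph-parallel rank owns
    # (illustrative modulo rule; the real partitioning logic lives in
    # _init_gp_partitions and differs from this).
    edge_mask = (edge_index_full[1] % world_size) == rank
    edge_index = edge_index_full[:, edge_mask]
    edge_distance = edge_distance_full[edge_mask]
    edge_distance_vec = edge_distance_vec_full[edge_mask]
    return edge_index, edge_distance, edge_distance_vec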
@@ -583,11 +583,22 @@ def _forward(self, data):
         ###############################################################
         sphharm_weights_edge = o3.spherical_harmonics(
             torch.arange(0, x.lmax_list[-1] + 1).tolist(),
-            edge_distance_vec,
+            edge_distance_vec_full,
             False,
         ).detach()
+        print(
+            edge_index.shape,
+            edge_index_full.shape,
+            x.embedding.shape,
+            edge_distance_vec_full.shape,
+            sphharm_weights_edge.shape,
+        )

-        x_edge = x.expand_edge(edge_index[1]).embedding
+        #if gp_utils.initialized():
+        #    x_full_embedding = gp_utils.gather_from_model_parallel_region(x.embedding, dim=0)
+        #    #x_
+        #    x_source.set_embedding(x_full)
+        x_edge = x.expand_edge(edge_index_full[1]).embedding
         x_edge = torch.einsum("abc, ab->ac", x_edge, sphharm_weights_edge)

         outputs = {
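
The closing einsum contracts each edge's spherical-harmonic weights with the source-node embedding expanded onto the full edge list, summing over the harmonic coefficients so that one feature vector per edge remains. A shape-level sketch of that contraction, with made-up sizes (in the model the coefficient dimension should be (lmax + 1) ** 2 and the channel count is a hyperparameter):

import torch

num_edges, num_coeffs, num_channels = 12, 16, 8  # illustrative sizes only

# Source-node embedding gathered onto every edge of the full graph.
x_edge = torch.randn(num_edges, num_coeffs, num_channels)
# Spherical harmonics of each full-graph edge direction.
sphharm_weights_edge = torch.randn(num_edges, num_coeffs)

# "abc, ab->ac": weight each harmonic coefficient of the embedding by the
# matching harmonic of the edge direction and sum over coefficients.
x_edge_proj = torch.einsum("abc, ab->ac", x_edge, sphharm_weights_edge)
assert x_edge_proj.shape == (num_edges, num_channels)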
