Skip to content

Commit

Permalink
layer block
Browse files Browse the repository at this point in the history
  • Loading branch information
rayg1234 committed Aug 30, 2024
1 parent 4017073 commit 52651a6
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 46 deletions.
91 changes: 51 additions & 40 deletions src/fairchem/core/models/escn/escn.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,9 +259,7 @@ def forward(self, data):
)

# Initialize the WignerD matrices and other values for spherical harmonic calculations
self.SO3_edge_rot = nn.ModuleList()
for i in range(self.num_resolutions):
self.SO3_edge_rot.append(SO3_Rotation(edge_rot_mat, self.lmax_list[i]))
self.SO3_edge_rot = SO3_Rotation(edge_rot_mat, self.lmax_list[0])

###############################################################
# Initialize node embeddings
Expand Down Expand Up @@ -290,29 +288,30 @@ def forward(self, data):
###############################################################
# Update spherical node embeddings
###############################################################

x_message = x.embedding
for i in range(self.num_layers):
if i > 0:
x_message = self.layer_blocks[i](
x,
x_message_new = self.layer_blocks[i](
x_message,
atomic_numbers,
graph.edge_distance,
graph.edge_index,
self.SO3_edge_rot,
)

# Residual layer for all layers past the first
x.embedding = x.embedding + x_message.embedding
x_xessage = x_message + x_message_new

else:
# No residual for the first layer
x = self.layer_blocks[i](
x,
x_message = self.layer_blocks[i](
x_message,
atomic_numbers,
graph.edge_distance,
graph.edge_index,
self.SO3_edge_rot,
)
x.embedding = x_message

# Sample the spherical channels (node embeddings) at evenly distributed points on the sphere.
# These values are fed into the output blocks.
Expand Down Expand Up @@ -448,9 +447,7 @@ def forward(self, data: Batch) -> dict[str, torch.Tensor]:
)

# Initialize the WignerD matrices and other values for spherical harmonic calculations
self.SO3_edge_rot = nn.ModuleList()
for i in range(self.num_resolutions):
self.SO3_edge_rot.append(SO3_Rotation(edge_rot_mat, self.lmax_list[i]))
self.SO3_edge_rot = SO3_Rotation(edge_rot_mat, self.lmax_list[0])

###############################################################
# Initialize node embeddings
Expand Down Expand Up @@ -640,12 +637,12 @@ def __init__(

def forward(
self,
x,
atomic_numbers,
edge_distance,
edge_index,
SO3_edge_rot,
):
x: torch.Tensor,
atomic_numbers: torch.Tensor,
edge_distance: torch.Tensor,
edge_index: torch.Tensor,
SO3_edge_rot: SO3_Rotation,
) -> torch.Tensor:
# Compute messages by performing message block
x_message = self.message_block(
x,
Expand All @@ -654,25 +651,33 @@ def forward(
edge_index,
SO3_edge_rot,
)
print(f"x_message: {x_message.mean()}")

# Compute point-wise spherical non-linearity on aggregated messages
max_lmax = max(self.lmax_list)

# Project to grid
x_grid_message = x_message.to_grid(self.SO3_grid["lmax_lmax"])
x_grid = x.to_grid(self.SO3_grid["lmax_lmax"])
# x_grid_message = x_message.to_grid(self.SO3_grid["lmax_lmax"])
to_grid_mat = self.SO3_grid["lmax_lmax"].to_grid_mat[:, :, self.SO3_grid["lmax_lmax"].mapping.coefficient_idx(self.lmax_list[0], self.lmax_list[0])]
x_grid_message = torch.einsum("bai,zic->zbac", to_grid_mat, x_message)

# x_grid = x.to_grid(self.SO3_grid["lmax_lmax"])
x_grid = torch.einsum("bai,zic->zbac", to_grid_mat, x)
x_grid = torch.cat([x_grid, x_grid_message], dim=3)

# Perform point-wise convolution
x_grid = self.act(self.fc1_sphere(x_grid))
x_grid = self.act(self.fc2_sphere(x_grid))
x_grid = self.fc3_sphere(x_grid)
print(f"x_grid: {x_grid.mean()}")

# Project back to spherical harmonic coefficients
x_message._from_grid(x_grid, self.SO3_grid["lmax_lmax"])
# x_message._from_grid(x_grid, self.SO3_grid["lmax_lmax"])
from_grid_mat = self.SO3_grid["lmax_lmax"].from_grid_mat[:, :, self.SO3_grid["lmax_lmax"].mapping.coefficient_idx(self.lmax_list[0], self.lmax_list[0])]
x_message_final = torch.einsum("bai,zbac->zic", from_grid_mat, x_grid)

print(f"x_message_final: {x_message_final.mean()}")
# Return aggregated messages
return x_message
return x_message_final


class MessageBlock(torch.nn.Module):
Expand Down Expand Up @@ -748,12 +753,12 @@ def __init__(

def forward(
self,
x,
atomic_numbers,
edge_distance,
edge_index,
SO3_edge_rot,
):
x: torch.Tensor,
atomic_numbers: torch.Tensor,
edge_distance: torch.Tensor,
edge_index: torch.Tensor,
SO3_edge_rot: SO3_Rotation,
) -> torch.Tensor:
###############################################################
# Compute messages
###############################################################
Expand All @@ -768,30 +773,36 @@ def forward(
# Copy embeddings for each edge's source and target nodes
x_source = x.clone()
x_target = x.clone()
x_source._expand_edge(edge_index[0, :])
x_target._expand_edge(edge_index[1, :])
x_source = x_source[edge_index[0, :]]
x_target = x_target[edge_index[1, :]]

# Rotate the irreps to align with the edge
x_source._rotate(SO3_edge_rot, self.lmax_list, self.mmax_list)
x_target._rotate(SO3_edge_rot, self.lmax_list, self.mmax_list)
x_source = SO3_edge_rot.rotate(x_source, self.lmax_list[0], self.mmax_list[0])
x_target = SO3_edge_rot.rotate(x_target, self.lmax_list[0], self.mmax_list[0])

# Compute messages
x_source.embedding = self.so2_block_source(x_source.embedding, x_edge)
x_target.embedding = self.so2_block_target(x_target.embedding, x_edge)
x_source = self.so2_block_source(x_source, x_edge)
x_target = self.so2_block_target(x_target, x_edge)

# Add together the source and target results
x_target.embedding = x_source.embedding + x_target.embedding
x_target = x_source + x_target

# Point-wise spherical non-linearity
x_target._grid_act(self.SO3_grid["lmax_mmax"], self.act, self.mappingReduced.res_size)
to_grid_mat = self.SO3_grid["lmax_mmax"].get_to_grid_mat()
from_grid_mat = self.SO3_grid["lmax_mmax"].get_from_grid_mat()
x_grid = torch.einsum("bai,zic->zbac", to_grid_mat, x_target)
x_grid = self.act(x_grid)
x_target = torch.einsum("bai,zbac->zic", from_grid_mat, x_grid)

# Rotate back the irreps
x_target._rotate_inv(SO3_edge_rot, self.mappingReduced.res_size)
x_target = SO3_edge_rot.rotate_inv(x_target, self.lmax_list[0], self.mmax_list[0])

# Compute the sum of the incoming neighboring messages for each target node
x_target._reduce_edge(edge_index[1], len(x.embedding))
new_embedding = torch.fill(x.clone(), 0)
new_embedding.index_add_(0, edge_index[1], x_target)
# x_target._reduce_edge(edge_index[1], len(x.embedding))

return x_target
return new_embedding


class SO2Block(torch.nn.Module):
Expand Down
12 changes: 6 additions & 6 deletions tests/core/models/test_escn_compiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def init(backend: str):
init_local_distributed_process_group(backend=backend)

class TestESCNCompiles:
def test_escn_baseline_cpu(self):
def test_escn_baseline_cpu(self, tol=1e-5):
init('gloo')
data = load_data()
data = data_list_collater([data])
Expand All @@ -102,11 +102,11 @@ def test_escn_baseline_cpu(self):
output = ddp_model(data)
expected_energy, expected_forces = expected_energy_forces()
torch.set_printoptions(precision=8)
assert torch.allclose(output["energy"], expected_energy)
assert torch.allclose(output["forces"].mean(0), expected_forces)
assert torch.allclose(output["energy"], expected_energy, atol=tol)
assert torch.allclose(output["forces"].mean(0), expected_forces, atol=tol)

@skip_if_no_cuda
def test_escn_baseline_cuda(self):
def test_escn_baseline_cuda(self, tol=1e-5):
init('nccl')
data = load_data()
data = data_list_collater([data]).to("cuda")
Expand All @@ -115,8 +115,8 @@ def test_escn_baseline_cuda(self):
output = ddp_model(data)
expected_energy, expected_forces = expected_energy_forces_cuda()
torch.set_printoptions(precision=8)
assert torch.allclose(output["energy"].cpu(), expected_energy)
assert torch.allclose(output["forces"].mean(0).cpu(), expected_forces)
assert torch.allclose(output["energy"].cpu(), expected_energy, atol=tol)
assert torch.allclose(output["forces"].mean(0).cpu(), expected_forces, atol=tol)

def test_escn_compiles(self):
init("gloo")
Expand Down

0 comments on commit 52651a6

Please sign in to comment.