Skip to content

Commit

Permalink
compile works, guard failures still
Browse files Browse the repository at this point in the history
  • Loading branch information
rayg1234 committed Aug 29, 2024
1 parent cdb4410 commit e4a426a
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 27 deletions.
28 changes: 16 additions & 12 deletions src/fairchem/core/models/escn/escn.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,14 +173,18 @@ def __init__(
)

# Initialize the transformations between spherical and grid representations
self.SO3_grid = nn.ModuleList()
for lval in range(max(self.lmax_list) + 1):
SO3_m_grid = nn.ModuleList()
for m in range(max(self.lmax_list) + 1):
SO3_m_grid.append(SO3_Grid(lval, m, resolution=resolution))

self.SO3_grid.append(SO3_m_grid)

assert self.num_resolutions == 1, "Only one resolution is supported"
self.SO3_grid = nn.ModuleDict()
self.SO3_grid["lmax_lmax"] = SO3_Grid(self.lmax_list[0], self.lmax_list[0], resolution=resolution)
self.SO3_grid["lmax_mmax"] = SO3_Grid(self.lmax_list[0], self.mmax_list[0], resolution=resolution)
# self.SO3_grid = nn.ModuleList()
# for lval in range(max(self.lmax_list) + 1):
# SO3_m_grid = nn.ModuleList()
# for m in range(max(self.lmax_list) + 1):
# SO3_m_grid.append(SO3_Grid(lval, m, resolution=resolution))

# self.SO3_grid.append(SO3_m_grid)
# import pdb;pdb.set_trace()
self.mappingReduced = CoefficientMapping(self.lmax_list, self.mmax_list)

# Initialize the blocks for each layer of the GNN
Expand Down Expand Up @@ -655,8 +659,8 @@ def forward(
max_lmax = max(self.lmax_list)

# Project to grid
x_grid_message = x_message.to_grid(self.SO3_grid, lmax=max_lmax)
x_grid = x.to_grid(self.SO3_grid, lmax=max_lmax)
x_grid_message = x_message.to_grid(self.SO3_grid["lmax_lmax"])
x_grid = x.to_grid(self.SO3_grid["lmax_lmax"])
x_grid = torch.cat([x_grid, x_grid_message], dim=3)

# Perform point-wise convolution
Expand All @@ -665,7 +669,7 @@ def forward(
x_grid = self.fc3_sphere(x_grid)

# Project back to spherical harmonic coefficients
x_message._from_grid(x_grid, self.SO3_grid, lmax=max_lmax)
x_message._from_grid(x_grid, self.SO3_grid["lmax_lmax"])

# Return aggregated messages
return x_message
Expand Down Expand Up @@ -780,7 +784,7 @@ def forward(
x_target.embedding = x_source.embedding + x_target.embedding

# Point-wise spherical non-linearity
x_target._grid_act(self.SO3_grid, self.act, self.mappingReduced.res_size)
x_target._grid_act(self.SO3_grid["lmax_mmax"], self.act, self.mappingReduced.res_size)

# Rotate back the irreps
x_target._rotate_inv(SO3_edge_rot, self.mappingReduced.res_size)
Expand Down
21 changes: 7 additions & 14 deletions src/fairchem/core/models/escn/so3.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,12 +274,8 @@ def _grid_act(self, SO3_grid, act, res_size) -> None:
num_coefficients = res_size[i]

x_res = self.embedding[:, offset : offset + num_coefficients].contiguous()
to_grid_mat = SO3_grid[self.lmax_list[i]][
self.mmax_list[i]
].get_to_grid_mat(self.device)
from_grid_mat = SO3_grid[self.lmax_list[i]][
self.mmax_list[i]
].get_from_grid_mat(self.device)
to_grid_mat = SO3_grid.get_to_grid_mat(self.device)
from_grid_mat = SO3_grid.get_from_grid_mat(self.device)

x_grid = torch.einsum("bai,zic->zbac", to_grid_mat, x_res)
x_grid = act(x_grid)
Expand All @@ -293,8 +289,8 @@ def to_grid(self, SO3_grid, lmax: int = -1) -> torch.Tensor:
if lmax == -1:
lmax = max(self.lmax_list)

to_grid_mat_lmax = SO3_grid[lmax][lmax].get_to_grid_mat(self.device)
grid_mapping = SO3_grid[lmax][lmax].mapping
to_grid_mat_lmax = SO3_grid.get_to_grid_mat(self.device)
grid_mapping = SO3_grid.mapping

offset = 0
x_grid = torch.tensor([], device=self.device)
Expand All @@ -316,12 +312,9 @@ def to_grid(self, SO3_grid, lmax: int = -1) -> torch.Tensor:
return x_grid

# Compute irreps from grid representation
def _from_grid(self, x_grid, SO3_grid, lmax: int = -1) -> None:
if lmax == -1:
lmax = max(self.lmax_list)

from_grid_mat_lmax = SO3_grid[lmax][lmax].get_from_grid_mat(self.device)
grid_mapping = SO3_grid[lmax][lmax].mapping
def _from_grid(self, x_grid, SO3_grid) -> None:
from_grid_mat_lmax = SO3_grid.get_from_grid_mat(self.device)
grid_mapping = SO3_grid.mapping

offset = 0
offset_channel = 0
Expand Down
5 changes: 4 additions & 1 deletion tests/core/models/test_escn_compiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,10 @@ def test_escn_compiles(self):
torch._dynamo.config.optimize_ddp = False
# torch._dynamo.explain(model)(data)
# assert False
compiled_model(data)
output = compiled_model(data)
expected_energy, expected_forces = expected_energy_forces()
assert torch.allclose(output["energy"], expected_energy)
assert torch.allclose(output["forces"].mean(0), expected_forces)

def test_rotation_invariance(self) -> None:
random.seed(1)
Expand Down

0 comments on commit e4a426a

Please sign in to comment.