Skip to content

Commit

Permalink
remove eqv2 stuff from this branch
Browse files Browse the repository at this point in the history
  • Loading branch information
rayg1234 committed Sep 2, 2024
1 parent cb679dd commit 2250fa6
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 463 deletions.
140 changes: 0 additions & 140 deletions src/fairchem/core/models/equiformer_v2/so2_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,6 @@ def __init__(
self.rad_func = RadialFunction(self.edge_channels_list)

def forward(self, x, x_edge):

num_edges = len(x_edge)
out = []

Expand Down Expand Up @@ -353,142 +352,3 @@ def forward(self, x, x_edge):
out_embedding._l_primary(self.mappingReduced)

return out_embedding

class SO2_Convolution_Exportable(torch.nn.Module):
"""
SO(2) Block: Perform SO(2) convolutions for all m (orders)
Args:
sphere_channels (int): Number of spherical channels
m_output_channels (int): Number of output channels used during the SO(2) conv
lmax_list (list:int): List of degrees (l) for each resolution
mmax_list (list:int): List of orders (m) for each resolution
mappingReduced (CoefficientMappingModule): Used to extract a subset of m components
internal_weights (bool): If True, not using radial function to multiply inputs features
edge_channels_list (list:int): List of sizes of invariant edge embedding. For example, [input_channels, hidden_channels, hidden_channels].
extra_m0_output_channels (int): If not None, return `out_embedding` (SO3_Embedding) and `extra_m0_features` (Tensor).
"""

def __init__(
self,
sphere_channels: int,
m_output_channels: int,
lmax_list: list[int],
mmax_list: list[int],
mappingReduced,
internal_weights: bool = True,
edge_channels_list: list[int] | None = None,
extra_m0_output_channels: int | None = None,
):
super().__init__()
self.sphere_channels = sphere_channels
self.m_output_channels = m_output_channels
self.lmax_list = lmax_list
self.mmax_list = mmax_list
self.mappingReduced = mappingReduced
self.num_resolutions = len(lmax_list)
self.internal_weights = internal_weights
self.edge_channels_list = copy.deepcopy(edge_channels_list)
self.extra_m0_output_channels = extra_m0_output_channels

num_channels_rad = 0 # for radial function

num_channels_m0 = 0
for i in range(self.num_resolutions):
num_coefficients = self.lmax_list[i] + 1
num_channels_m0 = num_channels_m0 + num_coefficients * self.sphere_channels

# SO(2) convolution for m = 0
m0_output_channels = self.m_output_channels * (
num_channels_m0 // self.sphere_channels
)
if self.extra_m0_output_channels is not None:
m0_output_channels = m0_output_channels + self.extra_m0_output_channels
self.fc_m0 = Linear(num_channels_m0, m0_output_channels)
num_channels_rad = num_channels_rad + self.fc_m0.in_features

# SO(2) convolution for non-zero m
self.so2_m_conv = nn.ModuleList()
for m in range(1, max(self.mmax_list) + 1):
self.so2_m_conv.append(
SO2_m_Convolution(
m,
self.sphere_channels,
self.m_output_channels,
self.lmax_list,
self.mmax_list,
)
)
num_channels_rad = num_channels_rad + self.so2_m_conv[-1].fc.in_features

# Embedding function of distance
self.rad_func = None
if not self.internal_weights:
assert self.edge_channels_list is not None
self.edge_channels_list.append(int(num_channels_rad))
self.rad_func = RadialFunction(self.edge_channels_list)

def forward(self, x_emb, x_edge):
# x_emb: [num_edges, num_sh_coefs, num_features]
# x_edge: [num_edges, num_edge_features]

num_edges = x_edge.shape[0]
out = []
# torch export does not inputs based on a buffered tensor
m_size = self.mappingReduced.m_size

# Reshape the spherical harmonics based on m (order), equivalent to x._m_primary
x_emb = torch.einsum("nac, ba -> nbc", x_emb, self.mappingReduced.to_m)

# radial function
if self.rad_func is not None:
x_edge = self.rad_func(x_edge)
offset_rad = 0

# Compute m=0 coefficients separately since they only have real values (no imaginary)
x_0 = x_emb.narrow(1, 0, m_size[0])
x_0 = x_0.reshape(x_edge.shape[0], -1)
if self.rad_func is not None:
x_edge_0 = x_edge.narrow(1, 0, self.fc_m0.in_features)
x_0 = x_0 * x_edge_0
x_0 = self.fc_m0(x_0)

x_0_extra = None
# extract extra m0 features
if self.extra_m0_output_channels is not None:
x_0_extra = x_0.narrow(-1, 0, self.extra_m0_output_channels)
x_0 = x_0.narrow(
-1,
self.extra_m0_output_channels,
(self.fc_m0.out_features - self.extra_m0_output_channels),
)

x_0 = x_0.view(num_edges, -1, self.m_output_channels)
out.append(x_0)
offset_rad = offset_rad + self.fc_m0.in_features

# Compute the values for the m > 0 coefficients
offset = m_size[0]
for m in range(1, max(self.mmax_list) + 1):
# Get the m order coefficients
x_m = x_emb.narrow(1, offset, 2 * m_size[m])
x_m = x_m.reshape(num_edges, 2, -1)

# Perform SO(2) convolution
if self.rad_func is not None:
x_edge_m = x_edge.narrow(
1, offset_rad, self.so2_m_conv[m - 1].fc.in_features
)
x_edge_m = x_edge_m.reshape(
num_edges, 1, self.so2_m_conv[m - 1].fc.in_features
)
x_m = x_m * x_edge_m
x_m = self.so2_m_conv[m - 1](x_m)
x_m = x_m.view(num_edges, -1, self.m_output_channels)
out.append(x_m)
offset = offset + 2 * m_size[m]
offset_rad = offset_rad + self.so2_m_conv[m - 1].fc.in_features

out = torch.cat(out, dim=1)
out = torch.einsum("nac, ab -> nbc", out, self.mappingReduced.to_m)
return out
Loading

0 comments on commit 2250fa6

Please sign in to comment.