diff --git a/src/fairchem/core/common/test_utils.py b/src/fairchem/core/common/test_utils.py index 130daba2d..ce86aa782 100644 --- a/src/fairchem/core/common/test_utils.py +++ b/src/fairchem/core/common/test_utils.py @@ -130,3 +130,13 @@ def spawn_multi_process( ) return [mp_output_dict[i] for i in range(config.world_size)] + +def init_local_distributed_process_group(backend="nccl"): + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(get_free_port()) + dist.init_process_group( + rank=0, + world_size=1, + backend=backend, + timeout=timedelta(seconds=10), # setting up timeout for distributed collectives + ) diff --git a/src/fairchem/core/datasets/lmdb_dataset.py b/src/fairchem/core/datasets/lmdb_dataset.py index ca1fcc2b7..346987d8e 100644 --- a/src/fairchem/core/datasets/lmdb_dataset.py +++ b/src/fairchem/core/datasets/lmdb_dataset.py @@ -211,7 +211,7 @@ def sample_property_metadata(self, num_samples: int = 100): } -def data_list_collater(data_list: list[BaseData], otf_graph: bool = False) -> BaseData: +def data_list_collater(data_list: list[BaseData], otf_graph: bool = False, to_dict: bool = False) -> BaseData | dict[str, torch.Tensor]: batch = Batch.from_data_list(data_list) if not otf_graph: @@ -226,4 +226,7 @@ def data_list_collater(data_list: list[BaseData], otf_graph: bool = False) -> Ba "LMDB does not contain edge index information, set otf_graph=True" ) + if to_dict: + batch = dict(batch.items()) + return batch diff --git a/src/fairchem/core/models/escn/escn_exportable.py b/src/fairchem/core/models/escn/escn_exportable.py new file mode 100644 index 000000000..c1a40ff59 --- /dev/null +++ b/src/fairchem/core/models/escn/escn_exportable.py @@ -0,0 +1,855 @@ +""" +Copyright (c) Meta, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. 
+""" + +from __future__ import annotations + +import contextlib +import logging + +import torch +import torch.nn as nn + +from fairchem.core.common.registry import registry +from fairchem.core.models.escn.so3_exportable import ( + CoefficientMapping, + SO3_Grid, + rotation_to_wigner, +) +from fairchem.core.models.scn.sampling import CalcSpherePoints +from fairchem.core.models.scn.smearing import ( + GaussianSmearing, + LinearSigmoidSmearing, + SigmoidSmearing, + SiLUSmearing, +) + +with contextlib.suppress(ImportError): + from e3nn import o3 + + +@registry.register_model("escn_export") +class eSCN(nn.Module): + """Equivariant Spherical Channel Network + Paper: Reducing SO(3) Convolutions to SO(2) for Efficient Equivariant GNNs + + + Args: + regress_forces (bool): Compute forces + cutoff (float): Maximum distance between nieghboring atoms in Angstroms + max_num_elements (int): Maximum atomic number + num_layers (int): Number of layers in the GNN + lmax (int): maximum degree of the spherical harmonics (1 to 10) + mmax (int): maximum order of the spherical harmonics (0 to lmax) + sphere_channels (int): Number of spherical channels (one set per resolution) + hidden_channels (int): Number of hidden units in message passing + num_sphere_samples (int): Number of samples used to approximate the integration of the sphere in the output blocks + edge_channels (int): Number of channels for the edge invariant features + distance_function ("gaussian", "sigmoid", "linearsigmoid", "silu"): Basis function used for distances + basis_width_scalar (float): Width of distance basis function + distance_resolution (float): Distance between distance basis functions in Angstroms + """ + + def __init__( + self, + regress_forces: bool = True, + cutoff: float = 8.0, + max_num_elements: int = 90, + num_layers: int = 8, + lmax: int = 4, + mmax: int = 2, + sphere_channels: int = 128, + hidden_channels: int = 256, + edge_channels: int = 128, + num_sphere_samples: int = 128, + distance_function: 
str = "gaussian", + basis_width_scalar: float = 1.0, + distance_resolution: float = 0.02, + resolution: int | None = None, + ) -> None: + super().__init__() + + import sys + + if "e3nn" not in sys.modules: + logging.error("You need to install the e3nn library to use the SCN model") + raise ImportError + + self.regress_forces = regress_forces + self.cutoff = cutoff + self.max_num_elements = max_num_elements + self.hidden_channels = hidden_channels + self.num_layers = num_layers + self.num_sphere_samples = num_sphere_samples + self.sphere_channels = sphere_channels + self.edge_channels = edge_channels + self.distance_resolution = distance_resolution + self.lmax = lmax + self.mmax = mmax + self.basis_width_scalar = basis_width_scalar + self.distance_function = distance_function + + # non-linear activation function used throughout the network + self.act = nn.SiLU() + + # Weights for message initialization + self.sphere_embedding = nn.Embedding( + self.max_num_elements, self.sphere_channels + ) + + # Initialize the function used to measure the distances between atoms + assert self.distance_function in [ + "gaussian", + "sigmoid", + "linearsigmoid", + "silu", + ] + self.num_gaussians = int(cutoff / self.distance_resolution) + if self.distance_function == "gaussian": + self.distance_expansion = GaussianSmearing( + 0.0, + cutoff, + self.num_gaussians, + basis_width_scalar, + ) + if self.distance_function == "sigmoid": + self.distance_expansion = SigmoidSmearing( + 0.0, + cutoff, + self.num_gaussians, + basis_width_scalar, + ) + if self.distance_function == "linearsigmoid": + self.distance_expansion = LinearSigmoidSmearing( + 0.0, + cutoff, + self.num_gaussians, + basis_width_scalar, + ) + if self.distance_function == "silu": + self.distance_expansion = SiLUSmearing( + 0.0, + cutoff, + self.num_gaussians, + basis_width_scalar, + ) + + # Initialize the transformations between spherical and grid representations + self.SO3_grid = nn.ModuleDict() + self.SO3_grid["lmax_lmax"] = 
SO3_Grid(self.lmax, self.lmax, resolution=resolution) + self.SO3_grid["lmax_mmax"] = SO3_Grid(self.lmax, self.mmax, resolution=resolution) + self.mappingReduced = CoefficientMapping([self.lmax], [self.mmax]) + + # Initialize the blocks for each layer of the GNN + self.layer_blocks = nn.ModuleList() + for i in range(self.num_layers): + block = LayerBlock( + i, + self.sphere_channels, + self.hidden_channels, + self.edge_channels, + self.lmax, + self.mmax, + self.distance_expansion, + self.max_num_elements, + self.SO3_grid, + self.act, + self.mappingReduced + ) + self.layer_blocks.append(block) + + # Output blocks for energy and forces + self.energy_block = EnergyBlock( + self.sphere_channels, self.num_sphere_samples, self.act + ) + if self.regress_forces: + self.force_block = ForceBlock( + self.sphere_channels, self.num_sphere_samples, self.act + ) + + # Create a roughly evenly distributed point sampling of the sphere for the output blocks + self.sphere_points = nn.Parameter( + CalcSpherePoints(self.num_sphere_samples), requires_grad=False + ) + + # For each spherical point, compute the spherical harmonic coefficient weights + self.sphharm_weights: nn.Parameter = nn.Parameter( + o3.spherical_harmonics( + torch.arange(0, self.lmax + 1).tolist(), + self.sphere_points, + False, + ), + requires_grad=False, + ) + + + def forward(self, data: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + pos: torch.Tensor = data["pos"] + batch_idx: torch.Tensor = data["batch"] + natoms: torch.Tensor = data["natoms"] + atomic_numbers: torch.Tensor = data["atomic_numbers"] + edge_index: torch.Tensor = data["edge_index"] + edge_distance: torch.Tensor = data["distances"] + edge_distance_vec: torch.Tensor = data["edge_distance_vec"] + + atomic_numbers = atomic_numbers.long() + # TODO: this requires upgrade to torch2.4 with export non-strict mode to enable + # assert ( + # atomic_numbers.max().item() < self.max_num_elements + # ), "Atomic number exceeds that given in model config" + 
num_atoms = len(atomic_numbers) + + ############################################################### + # Initialize data structures + ############################################################### + + # Compute 3x3 rotation matrix per edge + edge_rot_mat = self._init_edge_rot_mat( + edge_index, edge_distance_vec + ) + wigner = rotation_to_wigner(edge_rot_mat, 0, self.lmax).detach() + + ############################################################### + # Initialize node embeddings + ############################################################### + + # Init per node representations using an atomic number based embedding + x_message = torch.zeros( + num_atoms, + int((self.lmax + 1) ** 2), + self.sphere_channels, + device=pos.device, + dtype=pos.dtype, + ) + x_message[:, 0, :] = self.sphere_embedding(atomic_numbers) + + ############################################################### + # Update spherical node embeddings + ############################################################### + for i in range(self.num_layers): + if i > 0: + x_message_new = self.layer_blocks[i]( + x_message, + atomic_numbers, + edge_distance, + edge_index, + wigner, + ) + # Residual layer for all layers past the first + x_message = x_message + x_message_new + else: + # No residual for the first layer + x_message = self.layer_blocks[i]( + x_message, + atomic_numbers, + edge_distance, + edge_index, + wigner, + ) + + # Sample the spherical channels (node embeddings) at evenly distributed points on the sphere. + # These values are fed into the output blocks. 
+ x_pt = torch.einsum("abc, pb->apc", x_message, self.sphharm_weights).contiguous() + + ############################################################### + # Energy estimation + ############################################################### + node_energy = self.energy_block(x_pt) + energy = torch.zeros(len(natoms), device=node_energy.device) + energy.index_add_(0, batch_idx, node_energy.view(-1)) + # Scale energy to help balance numerical precision w.r.t. forces + energy = energy * 0.001 + + outputs = {"energy": energy} + ############################################################### + # Force estimation + ############################################################### + if self.regress_forces: + forces = self.force_block(x_pt, self.sphere_points) + outputs["forces"] = forces + + return outputs + + # Initialize the edge rotation matrics + def _init_edge_rot_mat(self, edge_index, edge_distance_vec): + edge_vec_0 = edge_distance_vec + edge_vec_0_distance = torch.sqrt(torch.sum(edge_vec_0**2, dim=1)) + + # Make sure the atoms are far enough apart + # TODO: this requires upgrade to torch2.4 with export non-strict mode to enable + # if torch.min(edge_vec_0_distance) < 0.0001: + # logging.error( + # f"Error edge_vec_0_distance: {torch.min(edge_vec_0_distance)}" + # ) + # (minval, minidx) = torch.min(edge_vec_0_distance, 0) + # logging.error( + # f"Error edge_vec_0_distance: {minidx} {edge_index[0, minidx]} {edge_index[1, minidx]} {data.pos[edge_index[0, minidx]]} {data.pos[edge_index[1, minidx]]}" + # ) + + norm_x = edge_vec_0 / (edge_vec_0_distance.view(-1, 1)) + + edge_vec_2 = torch.rand_like(edge_vec_0) - 0.5 + edge_vec_2 = edge_vec_2 / ( + torch.sqrt(torch.sum(edge_vec_2**2, dim=1)).view(-1, 1) + ) + # Create two rotated copys of the random vectors in case the random vector is aligned with norm_x + # With two 90 degree rotated vectors, at least one should not be aligned with norm_x + edge_vec_2b = edge_vec_2.clone() + edge_vec_2b[:, 0] = -edge_vec_2[:, 1] + 
edge_vec_2b[:, 1] = edge_vec_2[:, 0] + edge_vec_2c = edge_vec_2.clone() + edge_vec_2c[:, 1] = -edge_vec_2[:, 2] + edge_vec_2c[:, 2] = edge_vec_2[:, 1] + vec_dot_b = torch.abs(torch.sum(edge_vec_2b * norm_x, dim=1)).view(-1, 1) + vec_dot_c = torch.abs(torch.sum(edge_vec_2c * norm_x, dim=1)).view(-1, 1) + + vec_dot = torch.abs(torch.sum(edge_vec_2 * norm_x, dim=1)).view(-1, 1) + edge_vec_2 = torch.where(torch.gt(vec_dot, vec_dot_b), edge_vec_2b, edge_vec_2) + vec_dot = torch.abs(torch.sum(edge_vec_2 * norm_x, dim=1)).view(-1, 1) + edge_vec_2 = torch.where(torch.gt(vec_dot, vec_dot_c), edge_vec_2c, edge_vec_2) + + vec_dot = torch.abs(torch.sum(edge_vec_2 * norm_x, dim=1)) + # Check the vectors aren't aligned + assert torch.max(vec_dot) < 0.99 + + norm_z = torch.cross(norm_x, edge_vec_2, dim=1) + norm_z = norm_z / (torch.sqrt(torch.sum(norm_z**2, dim=1, keepdim=True))) + norm_z = norm_z / (torch.sqrt(torch.sum(norm_z**2, dim=1)).view(-1, 1)) + norm_y = torch.cross(norm_x, norm_z, dim=1) + norm_y = norm_y / (torch.sqrt(torch.sum(norm_y**2, dim=1, keepdim=True))) + + # Construct the 3D rotation matrix + norm_x = norm_x.view(-1, 3, 1) + norm_y = -norm_y.view(-1, 3, 1) + norm_z = norm_z.view(-1, 3, 1) + + edge_rot_mat_inv = torch.cat([norm_z, norm_x, norm_y], dim=2) + edge_rot_mat = torch.transpose(edge_rot_mat_inv, 1, 2) + + return edge_rot_mat.detach() + + @property + def num_params(self) -> int: + return sum(p.numel() for p in self.parameters()) + +class LayerBlock(torch.nn.Module): + """ + Layer block: Perform one layer (message passing and aggregation) of the GNN + + Args: + layer_idx (int): Layer number + sphere_channels (int): Number of spherical channels + hidden_channels (int): Number of hidden channels used during the SO(2) conv + edge_channels (int): Size of invariant edge embedding + lmax (int) degrees (l) for each resolution + mmax (int): orders (m) for each resolution + distance_expansion (func): Function used to compute distance embedding + max_num_elements 
(int): Maximum number of atomic numbers + SO3_grid (SO3_grid): Class used to convert from grid the spherical harmonic representations + act (function): Non-linear activation function + """ + + def __init__( + self, + layer_idx: int, + sphere_channels: int, + hidden_channels: int, + edge_channels: int, + lmax: int, + mmax: int, + distance_expansion, + max_num_elements: int, + SO3_grid: SO3_Grid, + act, + mappingReduced, + ) -> None: + super().__init__() + self.layer_idx = layer_idx + self.act = act + self.lmax = lmax + self.mmax = mmax + self.sphere_channels = sphere_channels + self.SO3_grid = SO3_grid + self.mappingReduced = mappingReduced + + # Message block + self.message_block = MessageBlock( + self.layer_idx, + self.sphere_channels, + hidden_channels, + edge_channels, + self.lmax, + self.mmax, + distance_expansion, + max_num_elements, + self.SO3_grid, + self.act, + self.mappingReduced + ) + + # Non-linear point-wise comvolution for the aggregated messages + self.fc1_sphere = nn.Linear( + 2 * self.sphere_channels, self.sphere_channels, bias=False + ) + + self.fc2_sphere = nn.Linear( + self.sphere_channels, self.sphere_channels, bias=False + ) + + self.fc3_sphere = nn.Linear( + self.sphere_channels, self.sphere_channels, bias=False + ) + + def forward( + self, + x: torch.Tensor, + atomic_numbers: torch.Tensor, + edge_distance: torch.Tensor, + edge_index: torch.Tensor, + wigner: torch.Tensor, + ) -> torch.Tensor: + # Compute messages by performing message block + x_message = self.message_block( + x, + atomic_numbers, + edge_distance, + edge_index, + wigner, + ) + + # Compute point-wise spherical non-linearity on aggregated messages + # Project to grid + to_grid_mat = self.SO3_grid["lmax_lmax"].to_grid_mat[:, :, self.SO3_grid["lmax_lmax"].mapping.coefficient_idx(self.lmax, self.lmax)] + x_grid_message = torch.einsum("bai,zic->zbac", to_grid_mat, x_message) + + # x_grid = x.to_grid(self.SO3_grid["lmax_lmax"]) + x_grid = torch.einsum("bai,zic->zbac", to_grid_mat, x) 
+ x_grid = torch.cat([x_grid, x_grid_message], dim=3) + + # Perform point-wise convolution + x_grid = self.act(self.fc1_sphere(x_grid)) + x_grid = self.act(self.fc2_sphere(x_grid)) + x_grid = self.fc3_sphere(x_grid) + + # Project back to spherical harmonic coefficients + from_grid_mat = self.SO3_grid["lmax_lmax"].from_grid_mat[:, :, self.SO3_grid["lmax_lmax"].mapping.coefficient_idx(self.lmax, self.lmax)] + return torch.einsum("bai,zbac->zic", from_grid_mat, x_grid) + + +class MessageBlock(torch.nn.Module): + """ + Message block: Perform message passing + + Args: + layer_idx (int): Layer number + sphere_channels (int): Number of spherical channels + hidden_channels (int): Number of hidden channels used during the SO(2) conv + edge_channels (int): Size of invariant edge embedding + lmax (int): degrees (l) for each resolution + mmax (int): orders (m) for each resolution + distance_expansion (func): Function used to compute distance embedding + max_num_elements (int): Maximum number of atomic numbers + SO3_grid (SO3_grid): Class used to convert from grid the spherical harmonic representations + act (function): Non-linear activation function + """ + + def __init__( + self, + layer_idx: int, + sphere_channels: int, + hidden_channels: int, + edge_channels: int, + lmax: int, + mmax: int, + distance_expansion, + max_num_elements: int, + SO3_grid: SO3_Grid, + act, + mappingReduced, + ) -> None: + super().__init__() + self.layer_idx = layer_idx + self.act = act + self.hidden_channels = hidden_channels + self.sphere_channels = sphere_channels + self.SO3_grid = SO3_grid + self.lmax = lmax + self.mmax = mmax + self.edge_channels = edge_channels + self.mappingReduced = mappingReduced + self.out_mask = self.mappingReduced.coefficient_idx(self.lmax, self.mmax) + + # Create edge scalar (invariant to rotations) features + self.edge_block = EdgeBlock( + self.edge_channels, + distance_expansion, + max_num_elements, + self.act, + ) + + # Create SO(2) convolution blocks + 
self.so2_block_source = SO2Block( + self.sphere_channels, + self.hidden_channels, + self.edge_channels, + self.lmax, + self.mmax, + self.act, + self.mappingReduced + ) + self.so2_block_target = SO2Block( + self.sphere_channels, + self.hidden_channels, + self.edge_channels, + self.lmax, + self.mmax, + self.act, + self.mappingReduced + ) + + def forward( + self, + x: torch.Tensor, + atomic_numbers: torch.Tensor, + edge_distance: torch.Tensor, + edge_index: torch.Tensor, + wigner: torch.Tensor, + ) -> torch.Tensor: + ############################################################### + # Compute messages + ############################################################### + # Compute edge scalar features (invariant to rotations) + # Uses atomic numbers and edge distance as inputs + x_edge = self.edge_block( + edge_distance, + atomic_numbers[edge_index[0]], # Source atom atomic number + atomic_numbers[edge_index[1]], # Target atom atomic number + ) + + # Copy embeddings for each edge's source and target nodes + x_source = x.clone() + x_target = x.clone() + x_source = x_source[edge_index[0, :]] + x_target = x_target[edge_index[1, :]] + + # Rotate the irreps to align with the edge + x_source = torch.bmm(wigner[:, self.out_mask, :], x_source) + x_target = torch.bmm(wigner[:, self.out_mask, :], x_target) + + # Compute messages + x_source = self.so2_block_source(x_source, x_edge) + x_target = self.so2_block_target(x_target, x_edge) + + # Add together the source and target results + x_target = x_source + x_target + + # Point-wise spherical non-linearity + to_grid_mat = self.SO3_grid["lmax_mmax"].get_to_grid_mat() + from_grid_mat = self.SO3_grid["lmax_mmax"].get_from_grid_mat() + x_grid = torch.einsum("bai,zic->zbac", to_grid_mat, x_target) + x_grid = self.act(x_grid) + x_target = torch.einsum("bai,zbac->zic", from_grid_mat, x_grid) + + # Rotate back the irreps + wigner_inv = torch.transpose(wigner, 1, 2).contiguous().detach() + x_target = torch.bmm(wigner_inv[:, :, self.out_mask], 
x_target) + + # Compute the sum of the incoming neighboring messages for each target node + new_embedding = torch.zeros(x.shape, dtype=x_target.dtype, device=x_target.device) + new_embedding.index_add_(0, edge_index[1], x_target) + + return new_embedding + + +class SO2Block(torch.nn.Module): + """ + SO(2) Block: Perform SO(2) convolutions for all m (orders) + + Args: + sphere_channels (int): Number of spherical channels + hidden_channels (int): Number of hidden channels used during the SO(2) conv + edge_channels (int): Size of invariant edge embedding + lmax (int): degrees (l) for each resolution + mmax (int): orders (m) for each resolution + act (function): Non-linear activation function + """ + + def __init__( + self, + sphere_channels: int, + hidden_channels: int, + edge_channels: int, + lmax: int, + mmax: int, + act, + mappingReduced + ) -> None: + super().__init__() + self.sphere_channels = sphere_channels + self.hidden_channels = hidden_channels + self.lmax = lmax + self.mmax = mmax + self.act = act + self.mappingReduced = mappingReduced + + num_channels_m0 = (self.lmax + 1) * self.sphere_channels + + # SO(2) convolution for m=0 + self.fc1_dist0 = nn.Linear(edge_channels, self.hidden_channels) + self.fc1_m0 = nn.Linear(num_channels_m0, self.hidden_channels, bias=False) + self.fc2_m0 = nn.Linear(self.hidden_channels, num_channels_m0, bias=False) + + # SO(2) convolution for non-zero m + self.so2_conv = nn.ModuleList() + for m in range(1, self.mmax + 1): + so2_conv = SO2Conv( + m, + self.sphere_channels, + self.hidden_channels, + edge_channels, + self.lmax, + self.mmax, + self.act, + ) + self.so2_conv.append(so2_conv) + + def forward( + self, + x: torch.Tensor, + x_edge: torch.Tensor, + ): + num_edges = len(x_edge) + + # Reshape the spherical harmonics based on m (order) + x = torch.einsum("nac,ba->nbc", x, self.mappingReduced.to_m) + + # Compute m=0 coefficients separately since they only have real values (no imaginary) + + # Compute edge scalar features for 
m=0 + x_edge_0 = self.act(self.fc1_dist0(x_edge)) + + x_0 = x[:, 0 : self.mappingReduced.m_size[0]].contiguous() + x_0 = x_0.view(num_edges, -1) + + x_0 = self.fc1_m0(x_0) + x_0 = x_0 * x_edge_0 + x_0 = self.fc2_m0(x_0) + x_0 = x_0.view(num_edges, -1, self.sphere_channels) + + # Update the m=0 coefficients + x[:, 0 : self.mappingReduced.m_size[0]] = x_0 + + # Compute the values for the m > 0 coefficients + offset = self.mappingReduced.m_size[0] + for m in range(1, self.mmax + 1): + # Get the m order coefficients + x_m = x[ + :, offset : offset + 2 * self.mappingReduced.m_size[m] + ].contiguous() + x_m = x_m.view(num_edges, 2, -1) + # Perform SO(2) convolution + x_m = self.so2_conv[m - 1](x_m, x_edge) + x_m = x_m.view(num_edges, -1, self.sphere_channels) + x[:, offset : offset + 2 * self.mappingReduced.m_size[m]] = x_m + + offset = offset + 2 * self.mappingReduced.m_size[m] + + # Reshape the spherical harmonics based on l (degree) + return torch.einsum("nac,ab->nbc", x, self.mappingReduced.to_m) + + +class SO2Conv(torch.nn.Module): + """ + SO(2) Conv: Perform an SO(2) convolution + + Args: + m (int): Order of the spherical harmonic coefficients + sphere_channels (int): Number of spherical channels + hidden_channels (int): Number of hidden channels used during the SO(2) conv + edge_channels (int): Size of invariant edge embedding + lmax (int): degrees (l) for each resolution + mmax (int): orders (m) for each resolution + act (function): Non-linear activation function + """ + + def __init__( + self, + m: int, + sphere_channels: int, + hidden_channels: int, + edge_channels: int, + lmax: int, + mmax: int, + act, + ) -> None: + super().__init__() + self.hidden_channels = hidden_channels + self.lmax = lmax + self.mmax = mmax + self.sphere_channels = sphere_channels + self.m = m + self.act = act + + num_coefficents = 0 + if self.mmax >= m: + num_coefficents = self.lmax - m + 1 + + num_channels = num_coefficents * self.sphere_channels + + assert num_channels > 0 + + # 
Embedding function of the distance + self.fc1_dist = nn.Linear(edge_channels, 2 * self.hidden_channels) + + # Real weights of SO(2) convolution + self.fc1_r = nn.Linear(num_channels, self.hidden_channels, bias=False) + self.fc2_r = nn.Linear(self.hidden_channels, num_channels, bias=False) + + # Imaginary weights of SO(2) convolution + self.fc1_i = nn.Linear(num_channels, self.hidden_channels, bias=False) + self.fc2_i = nn.Linear(self.hidden_channels, num_channels, bias=False) + + def forward(self, x_m, x_edge) -> torch.Tensor: + # Compute edge scalar features + x_edge = self.act(self.fc1_dist(x_edge)) + x_edge = x_edge.view(-1, 2, self.hidden_channels) + + # Perform the complex weight multiplication + x_r = self.fc1_r(x_m) + x_r = x_r * x_edge[:, 0:1, :] + x_r = self.fc2_r(x_r) + + x_i = self.fc1_i(x_m) + x_i = x_i * x_edge[:, 1:2, :] + x_i = self.fc2_i(x_i) + + x_m_r = x_r[:, 0] - x_i[:, 1] + x_m_i = x_r[:, 1] + x_i[:, 0] + + return torch.stack((x_m_r, x_m_i), dim=1).contiguous() + + +class EdgeBlock(torch.nn.Module): + """ + Edge Block: Compute invariant edge representation from edge diatances and atomic numbers + + Args: + edge_channels (int): Size of invariant edge embedding + distance_expansion (func): Function used to compute distance embedding + max_num_elements (int): Maximum number of atomic numbers + act (function): Non-linear activation function + """ + + def __init__( + self, + edge_channels, + distance_expansion, + max_num_elements, + act, + ) -> None: + super().__init__() + self.in_channels = distance_expansion.num_output + self.distance_expansion = distance_expansion + self.act = act + self.edge_channels = edge_channels + self.max_num_elements = max_num_elements + + # Embedding function of the distance + self.fc1_dist = nn.Linear(self.in_channels, self.edge_channels) + + # Embedding function of the atomic numbers + self.source_embedding = nn.Embedding(self.max_num_elements, self.edge_channels) + self.target_embedding = 
nn.Embedding(self.max_num_elements, self.edge_channels) + nn.init.uniform_(self.source_embedding.weight.data, -0.001, 0.001) + nn.init.uniform_(self.target_embedding.weight.data, -0.001, 0.001) + + # Embedding function of the edge + self.fc1_edge_attr = nn.Linear( + self.edge_channels, + self.edge_channels, + ) + + def forward(self, edge_distance, source_element, target_element): + # Compute distance embedding + x_dist = self.distance_expansion(edge_distance) + x_dist = self.fc1_dist(x_dist) + + # Compute atomic number embeddings + source_embedding = self.source_embedding(source_element) + target_embedding = self.target_embedding(target_element) + + # Compute invariant edge embedding + x_edge = self.act(source_embedding + target_embedding + x_dist) + return self.act(self.fc1_edge_attr(x_edge)) + + +class EnergyBlock(torch.nn.Module): + """ + Energy Block: Output block computing the energy + + Args: + num_channels (int): Number of channels + num_sphere_samples (int): Number of samples used to approximate the integral on the sphere + act (function): Non-linear activation function + """ + + def __init__( + self, + num_channels: int, + num_sphere_samples: int, + act, + ) -> None: + super().__init__() + self.num_channels = num_channels + self.num_sphere_samples = num_sphere_samples + self.act = act + + self.fc1 = nn.Linear(self.num_channels, self.num_channels) + self.fc2 = nn.Linear(self.num_channels, self.num_channels) + self.fc3 = nn.Linear(self.num_channels, 1, bias=False) + + def forward(self, x_pt) -> torch.Tensor: + # x_pt are the values of the channels sampled at different points on the sphere + x_pt = self.act(self.fc1(x_pt)) + x_pt = self.act(self.fc2(x_pt)) + x_pt = self.fc3(x_pt) + x_pt = x_pt.view(-1, self.num_sphere_samples, 1) + return torch.sum(x_pt, dim=1) / self.num_sphere_samples + + +class ForceBlock(torch.nn.Module): + """ + Force Block: Output block computing the per atom forces + + Args: + num_channels (int): Number of channels + num_sphere_samples 
(int): Number of samples used to approximate the integral on the sphere + act (function): Non-linear activation function + """ + + def __init__( + self, + num_channels: int, + num_sphere_samples: int, + act, + ) -> None: + super().__init__() + self.num_channels = num_channels + self.num_sphere_samples = num_sphere_samples + self.act = act + + self.fc1 = nn.Linear(self.num_channels, self.num_channels) + self.fc2 = nn.Linear(self.num_channels, self.num_channels) + self.fc3 = nn.Linear(self.num_channels, 1, bias=False) + + def forward(self, x_pt, sphere_points) -> torch.Tensor: + # x_pt are the values of the channels sampled at different points on the sphere + x_pt = self.act(self.fc1(x_pt)) + x_pt = self.act(self.fc2(x_pt)) + x_pt = self.fc3(x_pt) + x_pt = x_pt.view(-1, self.num_sphere_samples, 1) + forces = x_pt * sphere_points.view(1, self.num_sphere_samples, 3) + return torch.sum(forces, dim=1) / self.num_sphere_samples diff --git a/src/fairchem/core/models/escn/so3_exportable.py b/src/fairchem/core/models/escn/so3_exportable.py new file mode 100644 index 000000000..a5189d22d --- /dev/null +++ b/src/fairchem/core/models/escn/so3_exportable.py @@ -0,0 +1,343 @@ +from __future__ import annotations + +import math +import os + +import torch + +try: + from e3nn import o3 + from e3nn.o3 import FromS2Grid, ToS2Grid +except ImportError: + pass + +# Borrowed from e3nn @ 0.4.0: +# https://github.com/e3nn/e3nn/blob/0.4.0/e3nn/o3/_wigner.py#L10 +# _Jd is a list of tensors of shape (2l+1, 2l+1) +__Jd = torch.load(os.path.join(os.path.dirname(__file__), "Jd.pt")) +@torch.compiler.assume_constant_result +def get_jd() -> torch.Tensor: + return __Jd + + +# Borrowed from e3nn @ 0.4.0: +# https://github.com/e3nn/e3nn/blob/0.4.0/e3nn/o3/_wigner.py#L37 +# +# In 0.5.0, e3nn shifted to torch.matrix_exp which is significantly slower: +# https://github.com/e3nn/e3nn/blob/0.5.0/e3nn/o3/_wigner.py#L92 +def wigner_D( + lv: int, alpha: torch.Tensor, beta: torch.Tensor, gamma: torch.Tensor +) 
-> torch.Tensor: + _Jd = get_jd() + assert lv < len(_Jd), f"wigner D maximum l implemented is {len(_Jd) - 1}, send us an email to ask for more" + + alpha, beta, gamma = torch.broadcast_tensors(alpha, beta, gamma) + J = _Jd[lv].to(dtype=alpha.dtype, device=alpha.device) + Xa = _z_rot_mat(alpha, lv) + Xb = _z_rot_mat(beta, lv) + Xc = _z_rot_mat(gamma, lv) + return Xa @ J @ Xb @ J @ Xc + + +def _z_rot_mat(angle: torch.Tensor, lv: int) -> torch.Tensor: + shape, device, dtype = angle.shape, angle.device, angle.dtype + M = angle.new_zeros((*shape, 2 * lv + 1, 2 * lv + 1)) + inds = torch.arange(0, 2 * lv + 1, 1, device=device) + reversed_inds = torch.arange(2 * lv, -1, -1, device=device) + frequencies = torch.arange(lv, -lv - 1, -1, dtype=dtype, device=device) + M[..., inds, reversed_inds] = torch.sin(frequencies * angle[..., None]) + M[..., inds, inds] = torch.cos(frequencies * angle[..., None]) + return M + +def rotation_to_wigner( + edge_rot_mat: torch.Tensor, start_lmax: int, end_lmax: int +) -> torch.Tensor: + x = edge_rot_mat @ edge_rot_mat.new_tensor([0.0, 1.0, 0.0]) + alpha, beta = o3.xyz_to_angles(x) + R = ( + o3.angles_to_matrix(alpha, beta, torch.zeros_like(alpha)).transpose(-1, -2) + @ edge_rot_mat + ) + gamma = torch.atan2(R[..., 0, 2], R[..., 0, 0]) + + size = int((end_lmax + 1) ** 2) - int((start_lmax) ** 2) + wigner = torch.zeros(len(alpha), size, size, device=edge_rot_mat.device) + start = 0 + for lmax in range(start_lmax, end_lmax + 1): + block = wigner_D(lmax, alpha, beta, gamma) + end = start + block.size()[1] + wigner[:, start:end, start:end] = block + start = end + + return wigner.detach() + + +class CoefficientMapping(torch.nn.Module): + """ + Helper module for coefficients used to reshape l <--> m and to get coefficients of specific degree or order + + Args: + lmax_list (list:int): List of maximum degree of the spherical harmonics + mmax_list (list:int): List of maximum order of the spherical harmonics + use_rotate_inv_rescale (bool): Whether to 
pre-compute inverse rotation rescale matrices + """ + + def __init__( + self, + lmax_list, + mmax_list, + ): + super().__init__() + + self.lmax_list = lmax_list + self.mmax_list = mmax_list + self.num_resolutions = len(lmax_list) + + # TODO: remove this for loops here associated with lmax and mmax lists + assert len(self.lmax_list) == 1 + assert len(self.mmax_list) == 1 + + # Compute the degree (l) and order (m) for each entry of the embedding + l_harmonic = torch.tensor([]).long() + m_harmonic = torch.tensor([]).long() + m_complex = torch.tensor([]).long() + + self.res_size = torch.zeros([self.num_resolutions]).long().tolist() + + offset = 0 + for i in range(self.num_resolutions): + for l in range(self.lmax_list[i] + 1): + mmax = min(self.mmax_list[i], l) + m = torch.arange(-mmax, mmax + 1).long() + m_complex = torch.cat([m_complex, m], dim=0) + m_harmonic = torch.cat( + [m_harmonic, torch.abs(m).long()], dim=0 + ) + l_harmonic = torch.cat( + [l_harmonic, m.fill_(l).long()], dim=0 + ) + self.res_size[i] = len(l_harmonic) - offset + offset = len(l_harmonic) + + num_coefficients = len(l_harmonic) + # `self.to_m` moves m components from different L to contiguous index + to_m = torch.zeros([num_coefficients, num_coefficients]) + self.m_size = torch.zeros([max(self.mmax_list) + 1]).long().tolist() + + offset = 0 + for m in range(max(self.mmax_list) + 1): + idx_r, idx_i = self.complex_idx(m, -1, m_complex, l_harmonic) + + for idx_out, idx_in in enumerate(idx_r): + to_m[idx_out + offset, idx_in] = 1.0 + offset = offset + len(idx_r) + + self.m_size[m] = int(len(idx_r)) + + for idx_out, idx_in in enumerate(idx_i): + to_m[idx_out + offset, idx_in] = 1.0 + offset = offset + len(idx_i) + + to_m = to_m.detach() + + # save tensors and they will be moved to GPU + self.register_buffer("l_harmonic", l_harmonic) + self.register_buffer("m_harmonic", m_harmonic) + self.register_buffer("m_complex", m_complex) + self.register_buffer("to_m", to_m) + + self.pre_compute_coefficient_idx() 
+ + + # Return mask containing coefficients of order m (real and imaginary parts) + def complex_idx(self, m, lmax, m_complex, l_harmonic): + """ + Add `m_complex` and `l_harmonic` to the input arguments + since we cannot use `self.m_complex`. + """ + if lmax == -1: + lmax = max(self.lmax_list) + + indices = torch.arange(len(l_harmonic)) + # Real part + mask_r = torch.bitwise_and( + l_harmonic.le(lmax), m_complex.eq(m) + ) + mask_idx_r = torch.masked_select(indices, mask_r) + + mask_idx_i = torch.tensor([]).long() + # Imaginary part + if m != 0: + mask_i = torch.bitwise_and( + l_harmonic.le(lmax), m_complex.eq(-m) + ) + mask_idx_i = torch.masked_select(indices, mask_i) + + return mask_idx_r, mask_idx_i + + + def pre_compute_coefficient_idx(self): + """ + Pre-compute the results of `coefficient_idx()` and access them with `prepare_coefficient_idx()` + """ + lmax = max(self.lmax_list) + for l in range(lmax + 1): + for m in range(lmax + 1): + mask = torch.bitwise_and( + self.l_harmonic.le(l), self.m_harmonic.le(m) + ) + indices = torch.arange(len(mask)) + mask_indices = torch.masked_select(indices, mask) + self.register_buffer(f"coefficient_idx_l{l}_m{m}", mask_indices) + + + def prepare_coefficient_idx(self): + """ + Construct a list of buffers + """ + lmax = max(self.lmax_list) + coefficient_idx_list = [] + for l in range(lmax + 1): + l_list = [] + for m in range(lmax + 1): + l_list.append(getattr(self, f"coefficient_idx_l{l}_m{m}", None)) + coefficient_idx_list.append(l_list) + return coefficient_idx_list + + + # Return mask containing coefficients less than or equal to degree (l) and order (m) + def coefficient_idx(self, lmax: int, mmax: int): + if lmax > max(self.lmax_list) or mmax > max(self.lmax_list): + mask = torch.bitwise_and( + self.l_harmonic.le(lmax), self.m_harmonic.le(mmax) + ) + indices = torch.arange(len(mask), device=mask.device) + return torch.masked_select(indices, mask) + else: + temp = self.prepare_coefficient_idx() + return temp[lmax][mmax] + + + 
# NOTE(review): `pre_compute_rotate_inv_rescale` and `__repr__` take ``self``
# and appear to belong to the coefficient-mapping class whose header precedes
# this chunk.
def pre_compute_rotate_inv_rescale(self):
    """
    Register one buffer ``rotate_inv_rescale_l{l}_m{m}`` per ``(l, m)`` pair.

    Each buffer starts as ones of shape ``(1, (l + 1)**2, (l + 1)**2)``. For
    every degree ``l_sub`` with ``m < l_sub <= l`` the diagonal block of that
    degree is set to ``sqrt((2 * l_sub + 1) / (2 * m + 1))``. The last
    dimension is then restricted to ``self.coefficient_idx(l, m)``.
    """
    lmax = max(self.lmax_list)
    for l in range(lmax + 1):
        num_coefficients = (l + 1) ** 2
        for m in range(lmax + 1):
            mask_indices = self.coefficient_idx(l, m)
            rotate_inv_rescale = torch.ones(
                (1, num_coefficients, num_coefficients)
            )
            # Only degrees strictly above m are rescaled.
            for l_sub in range(m + 1, l + 1):
                start_idx = l_sub**2
                length = 2 * l_sub + 1
                stop_idx = start_idx + length
                rescale_factor = math.sqrt(length / (2 * m + 1))
                rotate_inv_rescale[
                    :, start_idx:stop_idx, start_idx:stop_idx
                ] = rescale_factor
            rotate_inv_rescale = rotate_inv_rescale[:, :, mask_indices]
            self.register_buffer(
                f"rotate_inv_rescale_l{l}_m{m}", rotate_inv_rescale
            )


def __repr__(self):
    return f"{self.__class__.__name__}(lmax_list={self.lmax_list}, mmax_list={self.mmax_list})"


class SO3_Grid(torch.nn.Module):
    """
    Helper functions for grid representation of the irreps

    Args:
        lmax (int): Maximum degree of the spherical harmonics
        mmax (int): Maximum order of the spherical harmonics
        normalization (str): Passed through to `ToS2Grid` / `FromS2Grid`
        resolution (int, optional): When given, used for both the latitude
            and the longitude resolution instead of the values derived from
            `lmax` and `mmax`
    """

    def __init__(
        self,
        lmax: int,
        mmax: int,
        normalization: str = "integral",
        resolution: int | None = None,
    ):
        super().__init__()

        self.lmax = lmax
        self.mmax = mmax
        self.lat_resolution = 2 * (self.lmax + 1)
        if lmax == mmax:
            self.long_resolution = 2 * (self.mmax + 1) + 1
        else:
            self.long_resolution = 2 * (self.mmax) + 1
        if resolution is not None:
            self.lat_resolution = resolution
            self.long_resolution = resolution

        self.mapping = CoefficientMapping([self.lmax], [self.lmax])

        # The matrices are built on the CPU and registered as buffers below.
        device = "cpu"
        grid_shape = (self.lat_resolution, self.long_resolution)

        to_grid = ToS2Grid(
            self.lmax,
            grid_shape,
            normalization=normalization,  # normalization="integral",
            device=device,
        )
        to_grid_mat = torch.einsum(
            "mbi, am -> bai", to_grid.shb, to_grid.sha
        ).detach()
        to_grid_mat = self._rescale_and_select(to_grid_mat)

        from_grid = FromS2Grid(
            grid_shape,
            self.lmax,
            normalization=normalization,  # normalization="integral",
            device=device,
        )
        from_grid_mat = torch.einsum(
            "am, mbi -> bai", from_grid.sha, from_grid.shb
        ).detach()
        from_grid_mat = self._rescale_and_select(from_grid_mat)

        # save tensors and they will be moved to GPU
        self.register_buffer("to_grid_mat", to_grid_mat)
        self.register_buffer("from_grid_mat", from_grid_mat)

    def _rescale_and_select(self, grid_mat):
        """Rescale the degrees above `mmax` in place, then keep the (lmax, mmax) coefficients."""
        # rescale based on mmax
        if self.lmax != self.mmax:
            for lval in range(self.mmax + 1, self.lmax + 1):
                start_idx = lval**2
                length = 2 * lval + 1
                stop_idx = start_idx + length
                rescale_factor = math.sqrt(length / (2 * self.mmax + 1))
                grid_mat[:, :, start_idx:stop_idx] = (
                    grid_mat[:, :, start_idx:stop_idx] * rescale_factor
                )
        return grid_mat[
            :, :, self.mapping.coefficient_idx(self.lmax, self.mmax)
        ]

    # Compute matrices to transform irreps to grid
    def get_to_grid_mat(self, device=None):
        # `device` is accepted but not used; the registered buffer is returned as is.
        return self.to_grid_mat

    # Compute matrices to transform grid to irreps
    def get_from_grid_mat(self, device=None):
        # `device` is accepted but not used; the registered buffer is returned as is.
        return self.from_grid_mat

    # Compute grid from irreps representation
    def to_grid(self, embedding, lmax: int, mmax: int):
        to_grid_mat = self.to_grid_mat[
            :, :, self.mapping.coefficient_idx(lmax, mmax)
        ]
        return torch.einsum("bai, zic -> zbac", to_grid_mat, embedding)

    # Compute irreps from grid representation
    def from_grid(self, grid, lmax: int, mmax: int):
        from_grid_mat = self.from_grid_mat[
            :, :, self.mapping.coefficient_idx(lmax, mmax)
        ]
        return torch.einsum("bai, zbac -> zic", from_grid_mat, grid)


# diff --git a/src/fairchem/core/preprocessing/atoms_to_graphs.py b/src/fairchem/core/preprocessing/atoms_to_graphs.py index 2283c40b8..473448de1 100644 ---
# a/src/fairchem/core/preprocessing/atoms_to_graphs.py +++ b/src/fairchem/core/preprocessing/atoms_to_graphs.py
#
# NOTE(review): the hunks of this patch that touch definitions only partly
# visible here are kept verbatim as comments; only `get_edge_distance_vec`
# is a complete definition.
#
# @@ -115,7 +115,6 @@ def _get_neighbors_pymatgen(self, atoms: ase.Atoms):
#          _c_index, _n_index, _offsets, n_distance = struct.get_neighbor_list(
#              r=self.radius, numerical_tol=0, exclude_self=True
#          )
# -        _nonmax_idx = []
#          for i in range(len(atoms)):
#              idx_i = (_c_index == i).nonzero()[0]
#
# @@ -148,6 +147,23 @@ def _reshape_features(self, c_index, n_index, n_distance, offsets):
#          return edge_index, edge_distances, cell_offsets
def get_edge_distance_vec(
    self,
    pos,
    edge_index,
    cell,
    cell_offsets,
):
    """Return ``pos[row] - pos[col]`` per edge, shifted by the periodic cell offsets.

    Args:
        pos: Tensor indexed by the two rows of `edge_index`
            (presumably ``(num_atoms, 3)`` positions -- TODO confirm).
        edge_index: Tensor that unpacks into ``row, col``; ``shape[1]`` is
            used as the number of edges.
        cell: Cell matrix, broadcast along dim 0 to one ``3 x 3`` matrix per
            edge (the caller in this patch passes a single cell).
        cell_offsets: Per-edge offsets, viewed as ``(-1, 1, 3)``.

    Returns:
        Tensor of per-edge distance vectors with the dtype of
        ``pos[row] - pos[col]``.
    """
    row, col = edge_index
    distance_vectors = pos[row] - pos[col]

    # correct for pbc
    # Use at least float32, and float64 when the positions are float64, so
    # the offsets are not rounded to single precision before being added.
    dtype = torch.promote_types(distance_vectors.dtype, torch.float32)
    # `expand` broadcasts the cell without materialising one copy per edge.
    cell = cell.to(dtype).expand(edge_index.shape[1], -1, -1)
    offsets = cell_offsets.to(dtype).view(-1, 1, 3).bmm(cell).view(-1, 3)
    distance_vectors += offsets

    return distance_vectors


#      def convert(self, atoms: ase.Atoms, sid=None):
#          """Convert a single atomic structure to a graph.
#
# @@ -203,6 +219,8 @@ def convert(self, atoms: ase.Atoms, sid=None):
#              data.edge_index = edge_index
#              data.cell_offsets = cell_offsets
# +            data.edge_distance_vec = self.get_edge_distance_vec(positions, edge_index, cell, cell_offsets)
# +
#          del atoms_copy
#          if self.r_energy:
#              energy = atoms.get_potential_energy(apply_constraint=False)
#
# diff --git a/tests/core/models/test_escn_compiles.py b/tests/core/models/test_escn_compiles.py new file mode 100644 index 000000000..269433d4d --- /dev/null +++ b/tests/core/models/test_escn_compiles.py @@ -0,0 +1,349 @@
# +"""
# +Copyright (c) Facebook, Inc. and its affiliates.
# +
# +This source code is licensed under the MIT license found in the
# +LICENSE file in the root directory of this source tree.
"""
Copyright (c) Facebook, Inc. and its affiliates.

This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""

from __future__ import annotations

import logging
import os
import random

import numpy as np
import pytest
import torch
from ase.io import read
from torch.export import Dim, export
from torch.nn.parallel.distributed import DistributedDataParallel

from fairchem.core.common.registry import registry
from fairchem.core.common.test_utils import init_local_distributed_process_group
from fairchem.core.common.transforms import RandomRotate
from fairchem.core.common.utils import setup_imports
from fairchem.core.datasets import data_list_collater
from fairchem.core.models.escn import escn_exportable
from fairchem.core.models.escn.so3_exportable import (
    CoefficientMapping,
    SO3_Grid,
)
from fairchem.core.models.scn.smearing import GaussianSmearing
from fairchem.core.preprocessing import AtomsToGraphs

skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="skipping when no gpu")


def load_data():
    """Read the first structure of ``atoms.json`` (next to this file) and convert it to a graph."""
    atoms = read(
        os.path.join(os.path.dirname(os.path.abspath(__file__)), "atoms.json"),
        index=0,
        format="json",
    )
    a2g = AtomsToGraphs(
        max_neigh=300,
        radius=6.0,
        r_edges=True,
        r_fixed=True,
        r_distances=True,
    )
    data_list = a2g.convert_all([atoms])
    return data_list[0]


def load_escn_model():
    """Build the registered ``escn`` model with a fixed torch seed."""
    torch.manual_seed(4)
    setup_imports()
    return registry.get_model_class("escn")(
        use_pbc=True,
        use_pbc_single=False,
        regress_forces=True,
        max_neighbors=300,
        cutoff=6.0,
        max_num_elements=90,
        num_layers=8,
        lmax_list=[4],
        mmax_list=[2],
        sphere_channels=128,
        hidden_channels=256,
        edge_channels=128,
        num_sphere_samples=128,
        distance_function="gaussian",
        basis_width_scalar=1.0,
        distance_resolution=0.02,
        resolution=None,
    )


def load_escn_exportable_model():
    """Build the registered ``escn_export`` model with the same seed and sizes as `load_escn_model`."""
    torch.manual_seed(4)
    setup_imports()
    return registry.get_model_class("escn_export")(
        regress_forces=True,
        cutoff=6.0,
        max_num_elements=90,
        num_layers=8,
        lmax=4,
        mmax=2,
        sphere_channels=128,
        hidden_channels=256,
        edge_channels=128,
        num_sphere_samples=128,
        distance_function="gaussian",
        basis_width_scalar=1.0,
        distance_resolution=0.02,
        resolution=None,
    )


def init(backend: str):
    """Create the local process group unless one is already initialized."""
    if not torch.distributed.is_initialized():
        init_local_distributed_process_group(backend=backend)


class TestESCNCompiles:
    def test_escn_baseline_cpu(self, tol=1e-8):
        init("gloo")
        data = load_data()
        data_tg = data_list_collater([data])
        data_export = data_list_collater([data], to_dict=True)

        base_model = DistributedDataParallel(load_escn_model())
        export_model = DistributedDataParallel(load_escn_exportable_model())
        base_output = base_model(data_tg)
        export_output = export_model(data_export)
        torch.set_printoptions(precision=8)
        assert torch.allclose(base_output["energy"], export_output["energy"], atol=tol)
        assert torch.allclose(base_output["forces"].mean(0), export_output["forces"].mean(0), atol=tol)

    @skip_if_no_cuda
    def test_escn_baseline_cuda(self, tol=1e-8):
        init("nccl")
        data = load_data()
        data_tg = data_list_collater([data]).to("cuda")
        data_export = data_list_collater([data], to_dict=True)
        data_export_cu = {k: v.to("cuda") for k, v in data_export.items()}

        base_model = DistributedDataParallel(load_escn_model().cuda())
        export_model = DistributedDataParallel(load_escn_exportable_model().cuda())
        base_output = base_model(data_tg)
        export_output = export_model(data_export_cu)
        torch.set_printoptions(precision=8)
        assert torch.allclose(base_output["energy"], export_output["energy"], atol=tol)
        assert torch.allclose(base_output["forces"].mean(0), export_output["forces"].mean(0), atol=tol)

    def test_rotation_invariance(self) -> None:
        random.seed(1)
        data = load_data()

        # Sampling a random rotation within [-180, 180] for all axes.
        transform = RandomRotate([-180, 180], [0, 1, 2])
        data_rotated, rot, inv_rot = transform(data.clone())
        assert not np.array_equal(data.pos, data_rotated.pos)

        # Pass it through the model.
        batch = data_list_collater([data, data_rotated], to_dict=True)
        model = load_escn_exportable_model()
        model.eval()
        out = model(batch)

        # Compare predicted energies and forces (after inv-rotation).
        energies = out["energy"].detach()
        np.testing.assert_almost_equal(energies[0], energies[1], decimal=7)

        forces = out["forces"].detach()
        logging.info(forces)
        np.testing.assert_array_almost_equal(
            forces[: forces.shape[0] // 2],
            torch.matmul(forces[forces.shape[0] // 2 :], inv_rot),
            decimal=5,
        )

    def test_escn_so2_conv_exports_and_compiles(self, tol=1e-5) -> None:
        torch._dynamo.config.assume_static_by_default = False
        torch._dynamo.config.automatic_dynamic_shapes = True
        inp1_dim0 = Dim("inp1_dim0")
        inp1_dim1 = None
        inp1_dim2 = None
        inp2_dim0 = inp1_dim0
        inp2_dim1 = None

        dynamic_shapes1 = {
            "x": {0: inp1_dim0, 1: inp1_dim1, 2: inp1_dim2},
            "x_edge": {0: inp2_dim0, 1: inp2_dim1},
        }

        lmax, mmax = 4, 2
        mappingReduced = escn_exportable.CoefficientMapping([lmax], [mmax])
        sphere_channels = 128
        edge_channels = 128
        args = (torch.rand(680, 19, sphere_channels), torch.rand(680, edge_channels))

        so2 = escn_exportable.SO2Block(
            sphere_channels=sphere_channels,
            hidden_channels=128,
            edge_channels=edge_channels,
            lmax=lmax,
            mmax=mmax,
            act=torch.nn.SiLU(),
            mappingReduced=mappingReduced,
        )
        prog = export(so2, args=args, dynamic_shapes=dynamic_shapes1)
        export_out = prog.module()(*args)
        regular_out = so2(*args)
        assert torch.allclose(export_out, regular_out, atol=tol)

        compiled_model = torch.compile(so2, dynamic=True)
        compiled_out = compiled_model(*args)
        assert torch.allclose(compiled_out, regular_out, atol=tol)

    def test_escn_message_block_exports_and_compiles(self, tol=1e-5) -> None:
        random.seed(1)

        sphere_channels = 128
        hidden_channels = 128
        edge_channels = 128
        lmax, mmax = 4, 2
        distance_expansion = GaussianSmearing(0.0, 8.0, int(8.0 / 0.02), 1.0)
        SO3_grid = torch.nn.ModuleDict()
        SO3_grid["lmax_lmax"] = SO3_Grid(lmax, lmax)
        SO3_grid["lmax_mmax"] = SO3_Grid(lmax, mmax)
        mappingReduced = CoefficientMapping([lmax], [mmax])
        message_block = escn_exportable.MessageBlock(
            layer_idx=0,
            sphere_channels=sphere_channels,
            hidden_channels=hidden_channels,
            edge_channels=edge_channels,
            lmax=lmax,
            mmax=mmax,
            distance_expansion=distance_expansion,
            max_num_elements=90,
            SO3_grid=SO3_grid,
            act=torch.nn.SiLU(),
            mappingReduced=mappingReduced,
        )

        # generate inputs
        batch_sizes = [34]
        num_coefs = 25
        num_edges = 2000
        wigner = torch.rand([num_edges, num_coefs, num_coefs])
        args = []
        for b in batch_sizes:
            x = torch.rand([b, num_coefs, sphere_channels])
            atom_n = torch.randint(1, 90, (b,))
            edge_d = torch.rand([num_edges])
            edge_indx = torch.randint(0, b, (2, num_edges))
            args.append((x, atom_n, edge_d, edge_indx, wigner))

        torch._dynamo.config.optimize_ddp = False
        torch._dynamo.config.assume_static_by_default = False
        torch._dynamo.config.automatic_dynamic_shapes = True
        torch._dynamo.config.verbose = True
        compiled_model = torch.compile(message_block, dynamic=True)
        compiled_output = compiled_model(*args[0])

        exported_prog = export(message_block, args=args[0])
        # Run the exported program through `.module()`, as the SO2 test does,
        # instead of calling the ExportedProgram object itself.
        exported_output = exported_prog.module()(*args[0])

        regular_out = message_block(*args[0])
        assert torch.allclose(compiled_output, regular_out, atol=tol)
        assert torch.allclose(exported_output, regular_out, atol=tol)

    def test_escn_layer_block_exports_and_compiles(self, tol=1e-5) -> None:
        random.seed(1)

        sphere_channels = 128
        hidden_channels = 128
        edge_channels = 128
        lmax, mmax = 4, 2
        distance_expansion = GaussianSmearing(0.0, 8.0, int(8.0 / 0.02), 1.0)
        SO3_grid = torch.nn.ModuleDict()
        SO3_grid["lmax_lmax"] = SO3_Grid(lmax, lmax)
        SO3_grid["lmax_mmax"] = SO3_Grid(lmax, mmax)
        mappingReduced = CoefficientMapping([lmax], [mmax])
        layer_block = escn_exportable.LayerBlock(
            layer_idx=0,
            sphere_channels=sphere_channels,
            hidden_channels=hidden_channels,
            edge_channels=edge_channels,
            lmax=lmax,
            mmax=mmax,
            distance_expansion=distance_expansion,
            max_num_elements=90,
            SO3_grid=SO3_grid,
            act=torch.nn.SiLU(),
            mappingReduced=mappingReduced,
        )

        # generate inputs
        batch_sizes = [34, 35, 35]
        num_edges = [680, 700, 680]
        num_coefs = 25
        run_args = []
        for b, edges in zip(batch_sizes, num_edges):
            x = torch.rand([b, num_coefs, sphere_channels])
            atom_n = torch.randint(1, 90, (b,))
            edge_d = torch.rand([edges])
            edge_indx = torch.randint(0, b, (2, edges))
            wigner = torch.rand([edges, num_coefs, num_coefs])
            run_args.append((x, atom_n, edge_d, edge_indx, wigner))

        torch._dynamo.config.optimize_ddp = False
        torch._dynamo.config.assume_static_by_default = False
        torch._dynamo.config.automatic_dynamic_shapes = True
        torch._dynamo.config.verbose = True

        batch_dim = Dim("batch_dim")
        edges_dim = Dim("edges_dim")
        dynamic_shapes1 = {
            "x": {0: batch_dim, 1: None, 2: None},
            "atomic_numbers": {0: batch_dim},
            "edge_distance": {0: edges_dim},
            "edge_index": {0: None, 1: edges_dim},
            "wigner": {0: edges_dim, 1: None, 2: None},
        }
        exported_prog = export(layer_block, args=run_args[0], dynamic_shapes=dynamic_shapes1)
        # Build both callables once; every input set below reuses them.
        exported_module = exported_prog.module()
        compiled_model = torch.compile(layer_block, dynamic=True)
        for run_arg in run_args:
            exported_output = exported_module(*run_arg)
            compiled_output = compiled_model(*run_arg)
            regular_out = layer_block(*run_arg)
            assert torch.allclose(compiled_output, regular_out, atol=tol)
            assert torch.allclose(exported_output, regular_out, atol=tol)

    def test_full_escn_compiles(self, tol=1e-5):
        init("gloo")
        data = load_data()
        regular_data = data_list_collater([data])
        compile_data = data_list_collater([data], to_dict=True)
        escn_model = DistributedDataParallel(load_escn_model())
        exportable_model = load_escn_exportable_model()

        torch._dynamo.config.optimize_ddp = False
        torch._dynamo.config.assume_static_by_default = False
        torch._dynamo.config.automatic_dynamic_shapes = True
        compiled_model = torch.compile(exportable_model, dynamic=True)
        output = compiled_model(compile_data)
        expected_output = escn_model(regular_data)
        assert torch.allclose(expected_output["energy"], output["energy"], atol=tol)
        assert torch.allclose(expected_output["forces"].mean(0), output["forces"].mean(0), atol=tol)

    def test_full_escn_exports(self):
        init("gloo")
        data = load_data()
        regular_data = data_list_collater([data])
        export_data = data_list_collater([data], to_dict=True)
        escn_model = load_escn_model()
        exportable_model = load_escn_exportable_model()

        torch._dynamo.config.optimize_ddp = False
        torch._dynamo.config.assume_static_by_default = False
        torch._dynamo.config.automatic_dynamic_shapes = True
        # torch._logging.set_logs(dynamo = logging.INFO)
        # torch._dynamo.reset()
        # explained_output = torch._dynamo.explain(model)(*data)
        # print(explained_output)
        # TODO: add dynamic shapes
        exported_prog = export(exportable_model, args=(export_data,))
        # Run the exported program through `.module()`, as the SO2 test does,
        # instead of calling the ExportedProgram object itself.
        export_output = exported_prog.module()(export_data)
        expected_output = escn_model(regular_data)
        assert torch.allclose(export_output["energy"], expected_output["energy"])
        assert torch.allclose(export_output["forces"].mean(0), expected_output["forces"].mean(0))