From f56b4407ec5bd996e45541d91b9c5969a9fa99aa Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Sun, 1 Sep 2024 20:16:52 -0700 Subject: [PATCH] message block compiles and exports --- .../core/models/escn/escn_exportable.py | 34 +++++------ src/fairchem/core/models/utils/so3_utils.py | 59 +++++++++---------- tests/core/models/test_escn_compiles.py | 28 ++++++--- 3 files changed, 61 insertions(+), 60 deletions(-) diff --git a/src/fairchem/core/models/escn/escn_exportable.py b/src/fairchem/core/models/escn/escn_exportable.py index 1f01fadd5..b13c8e5cd 100644 --- a/src/fairchem/core/models/escn/escn_exportable.py +++ b/src/fairchem/core/models/escn/escn_exportable.py @@ -25,6 +25,7 @@ CoefficientMapping, SO3_Grid, SO3_Rotation, + rotation_to_wigner ) from fairchem.core.models.escn.so3 import SO3_Embedding from fairchem.core.models.scn.sampling import CalcSpherePoints @@ -177,14 +178,6 @@ def __init__( self.SO3_grid = nn.ModuleDict() self.SO3_grid["lmax_lmax"] = SO3_Grid(self.lmax_list[0], self.lmax_list[0], resolution=resolution) self.SO3_grid["lmax_mmax"] = SO3_Grid(self.lmax_list[0], self.mmax_list[0], resolution=resolution) - # self.SO3_grid = nn.ModuleList() - # for lval in range(max(self.lmax_list) + 1): - # SO3_m_grid = nn.ModuleList() - # for m in range(max(self.lmax_list) + 1): - # SO3_m_grid.append(SO3_Grid(lval, m, resolution=resolution)) - - # self.SO3_grid.append(SO3_m_grid) - # import pdb;pdb.set_trace() self.mappingReduced = CoefficientMapping(self.lmax_list, self.mmax_list) # Initialize the blocks for each layer of the GNN @@ -257,9 +250,7 @@ def forward(self, data): edge_rot_mat = self._init_edge_rot_mat( data, graph.edge_index, graph.edge_distance_vec ) - - # Initialize the WignerD matrices and other values for spherical harmonic calculations - self.SO3_edge_rot = SO3_Rotation(edge_rot_mat, self.lmax_list[0]) + wigner = rotation_to_wigner(edge_rot_mat, 0, self.lmax_list[0]).detach() ############################################################### # Initialize node embeddings @@ -296,7 +287,7 @@ def forward(self, data): atomic_numbers, graph.edge_distance, graph.edge_index, - self.SO3_edge_rot, + wigner, ) # Residual layer for all layers past the first @@ -309,7 +300,7 @@ def forward(self, data): atomic_numbers, graph.edge_distance, graph.edge_index, - self.SO3_edge_rot, + wigner, ) x.embedding = x_message @@ -499,7 +490,7 @@ def forward( atomic_numbers: torch.Tensor, edge_distance: torch.Tensor, edge_index: torch.Tensor, - SO3_edge_rot: SO3_Rotation, + wigner: torch.Tensor, ) -> torch.Tensor: # Compute messages by performing message block x_message = self.message_block( @@ -507,7 +498,7 @@ def forward( atomic_numbers, edge_distance, edge_index, - SO3_edge_rot, + wigner, ) print(f"x_message: {x_message.mean()}") @@ -580,6 +571,7 @@ def __init__( self.mmax_list = mmax_list self.edge_channels = edge_channels self.mappingReduced = mappingReduced + self.out_mask = self.mappingReduced.coefficient_idx(self.lmax_list[0], self.mmax_list[0]) # Create edge scalar (invariant to rotations) features self.edge_block = EdgeBlock( @@ -615,7 +607,7 @@ def forward( atomic_numbers: torch.Tensor, edge_distance: torch.Tensor, edge_index: torch.Tensor, - SO3_edge_rot: SO3_Rotation, + wigner: torch.Tensor, ) -> torch.Tensor: ############################################################### # Compute messages @@ -635,8 +627,10 @@ def forward( x_target = x_target[edge_index[1, :]] # Rotate the irreps to align with the edge - x_source = SO3_edge_rot.rotate(x_source, self.lmax_list[0], self.mmax_list[0]) - x_target = SO3_edge_rot.rotate(x_target, self.lmax_list[0], self.mmax_list[0]) + x_source = torch.bmm(wigner[:, self.out_mask, :], x_source) + x_target = torch.bmm(wigner[:, self.out_mask, :], x_target) + # x_source = SO3_edge_rot.rotate(x_source, self.lmax_list[0], self.mmax_list[0]) + # x_target = SO3_edge_rot.rotate(x_target, self.lmax_list[0], self.mmax_list[0]) # Compute messages x_source = self.so2_block_source(x_source, x_edge) @@ -653,7 +647,9 @@ def forward( x_target = torch.einsum("bai,zbac->zic", from_grid_mat, x_grid) # Rotate back the irreps - x_target = SO3_edge_rot.rotate_inv(x_target, self.lmax_list[0], self.mmax_list[0]) + # x_target = SO3_edge_rot.rotate_inv(x_target, self.lmax_list[0], self.mmax_list[0]) + wigner_inv = torch.transpose(wigner, 1, 2).contiguous() + x_target = torch.bmm(wigner_inv[:, :, self.out_mask], x_target) # Compute the sum of the incoming neighboring messages for each target node new_embedding = torch.fill(x.clone(), 0) diff --git a/src/fairchem/core/models/utils/so3_utils.py b/src/fairchem/core/models/utils/so3_utils.py index 5845f4843..61e37744e 100644 --- a/src/fairchem/core/models/utils/so3_utils.py +++ b/src/fairchem/core/models/utils/so3_utils.py @@ -48,6 +48,28 @@ def _z_rot_mat(angle: torch.Tensor, lv: int) -> torch.Tensor: M[..., inds, inds] = torch.cos(frequencies * angle[..., None]) return M +def rotation_to_wigner( + edge_rot_mat: torch.Tensor, start_lmax: int, end_lmax: int +) -> torch.Tensor: + x = edge_rot_mat @ edge_rot_mat.new_tensor([0.0, 1.0, 0.0]) + alpha, beta = o3.xyz_to_angles(x) + R = ( + o3.angles_to_matrix(alpha, beta, torch.zeros_like(alpha)).transpose(-1, -2) + @ edge_rot_mat + ) + gamma = torch.atan2(R[..., 0, 2], R[..., 0, 0]) + + size = int((end_lmax + 1) ** 2) - int((start_lmax) ** 2) + wigner = torch.zeros(len(alpha), size, size, device=edge_rot_mat.device) + start = 0 + for lmax in range(start_lmax, end_lmax + 1): + block = wigner_D(lmax, alpha, beta, gamma) + end = start + block.size()[1] + wigner[:, start:end, start:end] = block + start = end + + return wigner.detach() + class CoefficientMapping(torch.nn.Module): """ @@ -63,13 +85,11 @@ def __init__( self, lmax_list, mmax_list, - use_rotate_inv_rescale=False ): super().__init__() self.lmax_list = lmax_list self.mmax_list = mmax_list - self.use_rotate_inv_rescale = use_rotate_inv_rescale self.num_resolutions = len(lmax_list) assert (len(self.lmax_list) == 1) and (len(self.mmax_list) == 1) @@ -121,13 +141,9 @@ def __init__( self.register_buffer('l_harmonic', l_harmonic) self.register_buffer('m_harmonic', m_harmonic) self.register_buffer('m_complex', m_complex) - # self.register_buffer('res_size', res_size) self.register_buffer('to_m', to_m) - # self.register_buffer('m_size', m_size) self.pre_compute_coefficient_idx() - if self.use_rotate_inv_rescale: - self.pre_compute_rotate_inv_rescale() # Return mask containing coefficients of order m (real and imaginary parts) @@ -215,30 +231,11 @@ def pre_compute_rotate_inv_rescale(self): rotate_inv_rescale[:, start_idx : (start_idx + length), start_idx : (start_idx + length)] = rescale_factor rotate_inv_rescale = rotate_inv_rescale[:, :, mask_indices] self.register_buffer('rotate_inv_rescale_l{}_m{}'.format(l, m), rotate_inv_rescale) - - - def prepare_rotate_inv_rescale(self): - lmax = max(self.lmax_list) - rotate_inv_rescale_list = [] - for l in range(lmax + 1): - l_list = [] - for m in range(lmax + 1): - l_list.append(getattr(self, 'rotate_inv_rescale_l{}_m{}'.format(l, m), None)) - rotate_inv_rescale_list.append(l_list) - return rotate_inv_rescale_list - - - # Return the re-scaling for rotating back to original frame - # this is required since we only use a subset of m components for SO(2) convolution - def get_rotate_inv_rescale(self, lmax, mmax): - temp = self.prepare_rotate_inv_rescale() - return temp[lmax][mmax] - def __repr__(self): return f"{self.__class__.__name__}(lmax_list={self.lmax_list}, mmax_list={self.mmax_list})" -class SO3_Rotation(torch.nn.Module): +class SO3_Rotation: """ Helper functions for Wigner-D rotations @@ -262,12 +259,10 @@ def __init__( self.wigner = self.wigner.detach() self.wigner_inv = self.wigner_inv.detach() - self.set_lmax(lmax) - - # Initialize coefficients for reshape l<-->m - def set_lmax(self, lmax) -> None: self.lmax = lmax - self.mapping = CoefficientMapping([self.lmax], [self.lmax], use_rotate_inv_rescale=True) + import pdb;pdb.set_trace() + self.mapping = CoefficientMapping([self.lmax], [self.lmax]) + # Rotate the embedding def rotate(self, embedding, out_lmax, out_mmax) -> torch.Tensor: @@ -286,7 +281,7 @@ def rotate_inv(self, embedding, in_lmax, in_mmax) -> torch.Tensor: def RotationToWignerDMatrix( self, edge_rot_mat: torch.Tensor, start_lmax: int, end_lmax: int ) -> torch.Tensor: - x = edge_rot_mat @ edge_rot_mat.new_tensor([0.0, 1.0, 0.0]) + x = edge_rot_mat[:,:,1] alpha, beta = o3.xyz_to_angles(x) R = ( o3.angles_to_matrix(alpha, beta, torch.zeros_like(alpha)).transpose(-1, -2) diff --git a/tests/core/models/test_escn_compiles.py b/tests/core/models/test_escn_compiles.py index 557d95cc7..c0458fbb9 100644 --- a/tests/core/models/test_escn_compiles.py +++ b/tests/core/models/test_escn_compiles.py @@ -29,7 +29,7 @@ from fairchem.core.models.scn.smearing import GaussianSmearing from fairchem.core.models.base import GraphModelMixin -from fairchem.core.models.utils.so3_utils import CoefficientMapping, SO3_Grid, SO3_Rotation +from fairchem.core.models.utils.so3_utils import CoefficientMapping, SO3_Grid, rotation_to_wigner from fairchem.core.models.escn import escn_exportable from torch.export import export @@ -175,7 +175,7 @@ def test_escn_so2_conv_exports_and_compiles(self, tol=1e-5) -> None: compiled_out = compiled_model(*args) assert torch.allclose(compiled_out, regular_out, atol=tol) - def test_escn_message_block_exports_and_compiles(self) -> None: + def test_escn_message_block_exports_and_compiles(self, tol=1e-5) -> None: random.seed(1) sphere_channels = 128 @@ -186,7 +186,7 @@ def test_escn_message_block_exports_and_compiles(self) -> None: SO3_grid = torch.nn.ModuleDict() SO3_grid["lmax_lmax"] = SO3_Grid(lmax, lmax) SO3_grid["lmax_mmax"] = SO3_Grid(lmax, mmax) - mappingReduced = escn_exportable.CoefficientMapping([lmax], [mmax]) + mappingReduced = CoefficientMapping([lmax], [mmax]) message_block = escn_exportable.MessageBlock( layer_idx = 0, sphere_channels = sphere_channels, @@ -208,7 +208,7 @@ def test_escn_message_block_exports_and_compiles(self) -> None: edge_rot_mat = full_model._init_edge_rot_mat( data, graph.edge_index, graph.edge_distance_vec ) - SO3_edge_rot = SO3_Rotation(edge_rot_mat, lmax) + wigner = rotation_to_wigner(edge_rot_mat, 0, lmax).detach() # generate inputs batch_sizes = [34] @@ -220,15 +220,25 @@ def test_escn_message_block_exports_and_compiles(self) -> None: atom_n = torch.randint(1, 90, (b,)) edge_d = torch.rand([num_edges]) edge_indx = torch.randint(0, b, (2, num_edges)) - args.append((x, atom_n, edge_d, edge_indx, SO3_edge_rot)) + args.append((x, atom_n, edge_d, edge_indx, wigner)) + + torch._dynamo.config.optimize_ddp = False + torch._dynamo.config.assume_static_by_default = False + torch._dynamo.config.automatic_dynamic_shapes = True + torch._dynamo.config.verbose = True + # torch._logging.set_logs(dynamo = logging.INFO) + # torch._dynamo.reset() + # explain_output = torch._dynamo.explain(message_block)(*args[0]) + # print(explain_output) + compiled_model = torch.compile(message_block, dynamic=True) + compiled_output = compiled_model(*args[0]) - # compiled_model = torch.compile(message_block, dynamic=True) exported_prog = export(message_block, args=args[0]) exported_output = exported_prog(*args[0]) - # output = message_block(*args) - # compiled_output = compiled_model(*args) - + regular_out = message_block(*args[0]) + assert torch.allclose(compiled_output, regular_out, atol=tol) + assert torch.allclose(exported_output, regular_out, atol=tol) def test_escn_compiles(self): init("gloo")