From 39572ddf29b8829a43e5d48880cadbc97db810bd Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Sat, 24 Aug 2024 12:42:07 -0700 Subject: [PATCH 01/25] update --- configs/s2ef/2M/base.yml | 4 +- src/fairchem/core/common/distutils.py | 11 + src/fairchem/core/common/utils.py | 8 +- src/fairchem/core/models/equiformer_v2/so3.py | 220 ++++++++++-------- src/fairchem/core/models/escn/so3.py | 2 +- src/fairchem/core/trainers/base_trainer.py | 10 + tests/core/models/atoms.json | 22 +- tests/core/models/test_eqv2_compiles.py | 129 ++++++++++ tests/core/models/test_escn_compiles.py | 96 ++++++++ 9 files changed, 401 insertions(+), 101 deletions(-) create mode 100644 tests/core/models/test_eqv2_compiles.py create mode 100644 tests/core/models/test_escn_compiles.py diff --git a/configs/s2ef/2M/base.yml b/configs/s2ef/2M/base.yml index cea1f121b0..69d8401bdd 100755 --- a/configs/s2ef/2M/base.yml +++ b/configs/s2ef/2M/base.yml @@ -3,7 +3,7 @@ trainer: ocp dataset: train: format: lmdb - src: data/s2ef/2M/train/ + src: /home/rgao/s2ef/s2ef/200k/train/ key_mapping: y: energy force: forces @@ -16,7 +16,7 @@ dataset: mean: 0 stdev: 2.887317180633545 val: - src: data/s2ef/all/val_id/ + src: /home/rgao/s2ef/s2ef/200k/train/ logger: wandb diff --git a/src/fairchem/core/common/distutils.py b/src/fairchem/core/common/distutils.py index f6bf88ccaf..8bf5b1d426 100644 --- a/src/fairchem/core/common/distutils.py +++ b/src/fairchem/core/common/distutils.py @@ -17,6 +17,7 @@ import torch.distributed as dist from fairchem.core.common.typing import none_throws +from torch.distributed.elastic.utils.distributed import get_free_port T = TypeVar("T") @@ -192,3 +193,13 @@ def gather_objects(data: T, group: dist.ProcessGroup = dist.group.WORLD) -> list output = [None for _ in range(get_world_size())] if is_master() else None dist.gather_object(data, output, group=group, dst=0) return output + +def init_local_distributed_process_group(backend="nccl"): + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(get_free_port()) + dist.init_process_group( + rank=0, + world_size=1, + backend=backend, + timeout=timedelta(seconds=10), # setting up timeout for distributed collectives + ) diff --git a/src/fairchem/core/common/utils.py b/src/fairchem/core/common/utils.py index 955ea1e062..b0eff00ce6 100644 --- a/src/fairchem/core/common/utils.py +++ b/src/fairchem/core/common/utils.py @@ -763,6 +763,7 @@ def radius_graph_pbc( atom_distance=atom_distance_sqr, max_num_neighbors_threshold=max_num_neighbors_threshold, enforce_max_strictly=enforce_max_neighbors_strictly, + batch=data.batch, ) if not torch.all(mask_num_neighbors): @@ -786,6 +787,7 @@ def get_max_neighbors_mask( max_num_neighbors_threshold, degeneracy_tolerance: float = 0.01, enforce_max_strictly: bool = False, + batch=None, ): """ Give a mask that filters out edges so that each atom has at most @@ -808,14 +810,12 @@ def get_max_neighbors_mask( # Get number of neighbors # segment_coo assumes sorted index ones = index.new_ones(1).expand_as(index) - num_neighbors = segment_coo(ones, index, dim_size=num_atoms) + num_neighbors = scatter(ones, index, dim_size=num_atoms) max_num_neighbors = num_neighbors.max() num_neighbors_thresholded = num_neighbors.clamp(max=max_num_neighbors_threshold) # Get number of (thresholded) neighbors per image - image_indptr = torch.zeros(natoms.shape[0] + 1, device=device, dtype=torch.long) - image_indptr[1:] = torch.cumsum(natoms, dim=0) - num_neighbors_image = segment_csr(num_neighbors_thresholded, image_indptr) + num_neighbors_image = 
scatter(num_neighbors_thresholded, batch, dim_size=natoms.shape[0]) # If max_num_neighbors is below the threshold, return early if ( diff --git a/src/fairchem/core/models/equiformer_v2/so3.py b/src/fairchem/core/models/equiformer_v2/so3.py index a3d58586e0..909290e319 100644 --- a/src/fairchem/core/models/equiformer_v2/so3.py +++ b/src/fairchem/core/models/equiformer_v2/so3.py @@ -30,52 +30,56 @@ class CoefficientMappingModule(torch.nn.Module): """ - Helper module for coefficients used to reshape lval <--> m and to get coefficients of specific degree or order + Helper module for coefficients used to reshape l <--> m and to get coefficients of specific degree or order Args: lmax_list (list:int): List of maximum degree of the spherical harmonics mmax_list (list:int): List of maximum order of the spherical harmonics + use_rotate_inv_rescale (bool): Whether to pre-compute inverse rotation rescale matrices """ def __init__( self, - lmax_list: list[int], - mmax_list: list[int], + lmax_list, + mmax_list, + use_rotate_inv_rescale=False ): super().__init__() self.lmax_list = lmax_list self.mmax_list = mmax_list + self.use_rotate_inv_rescale = use_rotate_inv_rescale self.num_resolutions = len(lmax_list) - # Temporarily use `cpu` as device and this will be overwritten. - self.device = "cpu" - - # Compute the degree (lval) and order (m) for each entry of the embedding - l_harmonic = torch.tensor([], device=self.device).long() - m_harmonic = torch.tensor([], device=self.device).long() - m_complex = torch.tensor([], device=self.device).long() + assert (len(self.lmax_list) == 1) and (len(self.mmax_list) == 1) + + # Compute the degree (l) and order (m) for each entry of the embedding + l_harmonic = torch.tensor([]).long() + m_harmonic = torch.tensor([]).long() + m_complex = torch.tensor([]).long() - res_size = torch.zeros([self.num_resolutions], device=self.device).long() + res_size = torch.zeros([self.num_resolutions]).long() offset = 0 for i in range(self.num_resolutions): - for lval in range(self.lmax_list[i] + 1): - mmax = min(self.mmax_list[i], lval) - m = torch.arange(-mmax, mmax + 1, device=self.device).long() + for l in range(0, self.lmax_list[i] + 1): + mmax = min(self.mmax_list[i], l) + m = torch.arange(-mmax, mmax + 1).long() m_complex = torch.cat([m_complex, m], dim=0) - m_harmonic = torch.cat([m_harmonic, torch.abs(m).long()], dim=0) - l_harmonic = torch.cat([l_harmonic, m.fill_(lval).long()], dim=0) + m_harmonic = torch.cat( + [m_harmonic, torch.abs(m).long()], dim=0 + ) + l_harmonic = torch.cat( + [l_harmonic, m.fill_(l).long()], dim=0 + ) res_size[i] = len(l_harmonic) - offset offset = len(l_harmonic) num_coefficients = len(l_harmonic) # `self.to_m` moves m components from different L to contiguous index - to_m = torch.zeros([num_coefficients, num_coefficients], device=self.device) - m_size = torch.zeros([max(self.mmax_list) + 1], device=self.device).long() + to_m = torch.zeros([num_coefficients, num_coefficients]) + m_size = torch.zeros([max(self.mmax_list) + 1]).long() - # The following is implemented poorly - very slow. It only gets called - # a few times so haven't optimized. 
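+        # Build `to_m` row by row: for each order m, real components are placed
+        # first and imaginary components after them, so all coefficients of a
+        # given order end up in one contiguous block of rows.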
offset = 0 for m in range(max(self.mmax_list) + 1): idx_r, idx_i = self.complex_idx(m, -1, m_complex, l_harmonic) @@ -93,93 +97,124 @@ def __init__( to_m = to_m.detach() # save tensors and they will be moved to GPU - self.register_buffer("l_harmonic", l_harmonic) - self.register_buffer("m_harmonic", m_harmonic) - self.register_buffer("m_complex", m_complex) - self.register_buffer("res_size", res_size) - self.register_buffer("to_m", to_m) - self.register_buffer("m_size", m_size) - - # for caching the output of `coefficient_idx` - self.lmax_cache, self.mmax_cache = None, None - self.mask_indices_cache = None - self.rotate_inv_rescale_cache = None + self.register_buffer('l_harmonic', l_harmonic) + self.register_buffer('m_harmonic', m_harmonic) + self.register_buffer('m_complex', m_complex) + self.register_buffer('res_size', res_size) + self.register_buffer('to_m', to_m) + self.register_buffer('m_size', m_size) + + self.pre_compute_coefficient_idx() + if self.use_rotate_inv_rescale: + self.pre_compute_rotate_inv_rescale() + # Return mask containing coefficients of order m (real and imaginary parts) - def complex_idx(self, m: int, lmax: int, m_complex, l_harmonic): - """ - Add `m_complex` and `l_harmonic` to the input arguments - since we cannot use `self.m_complex`. - """ + def complex_idx(self, m, lmax, m_complex, l_harmonic): + ''' + Add `m_complex` and `l_harmonic` to the input arguments + since we cannot use `self.m_complex`. + ''' if lmax == -1: lmax = max(self.lmax_list) - indices = torch.arange(len(l_harmonic), device=self.device) + indices = torch.arange(len(l_harmonic)) # Real part - mask_r = torch.bitwise_and(l_harmonic.le(lmax), m_complex.eq(m)) + mask_r = torch.bitwise_and( + l_harmonic.le(lmax), m_complex.eq(m) + ) mask_idx_r = torch.masked_select(indices, mask_r) - mask_idx_i = torch.tensor([], device=self.device).long() + mask_idx_i = torch.tensor([]).long() # Imaginary part if m != 0: - mask_i = torch.bitwise_and(l_harmonic.le(lmax), m_complex.eq(-m)) + mask_i = torch.bitwise_and( + l_harmonic.le(lmax), m_complex.eq(-m) + ) mask_idx_i = torch.masked_select(indices, mask_i) return mask_idx_r, mask_idx_i - # Return mask containing coefficients less than or equal to degree (lval) and order (m) - def coefficient_idx(self, lmax: int, mmax: int): - if ( - (self.lmax_cache is not None) - and (self.mmax_cache is not None) - and (self.lmax_cache == lmax) - and (self.mmax_cache == mmax) - and self.mask_indices_cache is not None - ): - return self.mask_indices_cache - - mask = torch.bitwise_and(self.l_harmonic.le(lmax), self.m_harmonic.le(mmax)) - self.device = mask.device - indices = torch.arange(len(mask), device=self.device) - mask_indices = torch.masked_select(indices, mask) - self.lmax_cache, self.mmax_cache = lmax, mmax - self.mask_indices_cache = mask_indices - return self.mask_indices_cache + + def pre_compute_coefficient_idx(self): + ''' + Pre-compute the results of `coefficient_idx()` and access them with `prepare_coefficient_idx()` + ''' + lmax = max(self.lmax_list) + for l in range(lmax + 1): + for m in range(lmax + 1): + mask = torch.bitwise_and( + self.l_harmonic.le(l), self.m_harmonic.le(m) + ) + indices = torch.arange(len(mask)) + mask_indices = torch.masked_select(indices, mask) + self.register_buffer('coefficient_idx_l{}_m{}'.format(l, m), mask_indices) + + + def prepare_coefficient_idx(self): + ''' + Construct a list of buffers + ''' + lmax = max(self.lmax_list) + coefficient_idx_list = [] + for l in range(lmax + 1): + l_list = [] + for m in range(lmax + 1): + 
l_list.append(getattr(self, 'coefficient_idx_l{}_m{}'.format(l, m), None)) + coefficient_idx_list.append(l_list) + return coefficient_idx_list + + + # Return mask containing coefficients less than or equal to degree (l) and order (m) + def coefficient_idx(self, lmax, mmax): + if lmax > max(self.lmax_list) or mmax > max(self.lmax_list): + mask = torch.bitwise_and( + self.l_harmonic.le(lmax), self.m_harmonic.le(mmax) + ) + indices = torch.arange(len(mask), device=mask.device) + mask_indices = torch.masked_select(indices, mask) + return mask_indices + else: + temp = self.prepare_coefficient_idx() + return temp[lmax][mmax] + + + def pre_compute_rotate_inv_rescale(self): + lmax = max(self.lmax_list) + for l in range(lmax + 1): + for m in range(lmax + 1): + mask_indices = self.coefficient_idx(l, m) + rotate_inv_rescale = torch.ones((1, int((l + 1)**2), int((l + 1)**2))) + for l_sub in range(l + 1): + if l_sub <= m: + continue + start_idx = l_sub ** 2 + length = 2 * l_sub + 1 + rescale_factor = math.sqrt(length / (2 * m + 1)) + rotate_inv_rescale[:, start_idx : (start_idx + length), start_idx : (start_idx + length)] = rescale_factor + rotate_inv_rescale = rotate_inv_rescale[:, :, mask_indices] + self.register_buffer('rotate_inv_rescale_l{}_m{}'.format(l, m), rotate_inv_rescale) + + + def prepare_rotate_inv_rescale(self): + lmax = max(self.lmax_list) + rotate_inv_rescale_list = [] + for l in range(lmax + 1): + l_list = [] + for m in range(lmax + 1): + l_list.append(getattr(self, 'rotate_inv_rescale_l{}_m{}'.format(l, m), None)) + rotate_inv_rescale_list.append(l_list) + return rotate_inv_rescale_list + # Return the re-scaling for rotating back to original frame # this is required since we only use a subset of m components for SO(2) convolution - def get_rotate_inv_rescale(self, lmax: int, mmax: int): - if ( - (self.lmax_cache is not None) - and (self.mmax_cache is not None) - and (self.lmax_cache == lmax) - and (self.mmax_cache == mmax) - and self.rotate_inv_rescale_cache is not None - ): - return self.rotate_inv_rescale_cache - - if self.mask_indices_cache is None: - self.coefficient_idx(lmax, mmax) - - rotate_inv_rescale = torch.ones( - (1, (lmax + 1) ** 2, (lmax + 1) ** 2), device=self.device - ) - for lval in range(lmax + 1): - if lval <= mmax: - continue - start_idx = lval**2 - length = 2 * lval + 1 - rescale_factor = math.sqrt(length / (2 * mmax + 1)) - rotate_inv_rescale[ - :, - start_idx : (start_idx + length), - start_idx : (start_idx + length), - ] = rescale_factor - rotate_inv_rescale = rotate_inv_rescale[:, :, self.mask_indices_cache] - self.rotate_inv_rescale_cache = rotate_inv_rescale - return self.rotate_inv_rescale_cache - - def __repr__(self) -> str: + def get_rotate_inv_rescale(self, lmax, mmax): + temp = self.prepare_rotate_inv_rescale() + return temp[lmax][mmax] + + + def __repr__(self): return f"{self.__class__.__name__}(lmax_list={self.lmax_list}, mmax_list={self.mmax_list})" @@ -447,7 +482,7 @@ def __init__( ): super().__init__() self.lmax = lmax - self.mapping = CoefficientMappingModule([self.lmax], [self.lmax]) + self.mapping = CoefficientMappingModule([self.lmax], [self.lmax], use_rotate_inv_rescale=True) def set_wigner(self, rot_mat3x3): self.device, self.dtype = rot_mat3x3.device, rot_mat3x3.dtype @@ -482,7 +517,7 @@ def RotationToWignerDMatrix( ) gamma = torch.atan2(R[..., 0, 2], R[..., 0, 0]) - size = (end_lmax + 1) ** 2 - (start_lmax) ** 2 + size = int((end_lmax + 1) ** 2) - int((start_lmax) ** 2) wigner = torch.zeros(len(alpha), size, size, device=self.device) start 
= 0 for lmax in range(start_lmax, end_lmax + 1): @@ -511,6 +546,7 @@ def __init__( resolution: int | None = None, ): super().__init__() + self.lmax = lmax self.mmax = mmax self.lat_resolution = 2 * (self.lmax + 1) diff --git a/src/fairchem/core/models/escn/so3.py b/src/fairchem/core/models/escn/so3.py index 34f505d51e..fec2b85a3c 100644 --- a/src/fairchem/core/models/escn/so3.py +++ b/src/fairchem/core/models/escn/so3.py @@ -403,7 +403,7 @@ def RotationToWignerDMatrix( ) gamma = torch.atan2(R[..., 0, 2], R[..., 0, 0]) - size = (end_lmax + 1) ** 2 - (start_lmax) ** 2 + size = int((end_lmax + 1) ** 2) - int((start_lmax) ** 2) wigner = torch.zeros(len(alpha), size, size, device=self.device) start = 0 for lmax in range(start_lmax, end_lmax + 1): diff --git a/src/fairchem/core/trainers/base_trainer.py b/src/fairchem/core/trainers/base_trainer.py index 94becb924c..6918d45fc5 100644 --- a/src/fairchem/core/trainers/base_trainer.py +++ b/src/fairchem/core/trainers/base_trainer.py @@ -551,6 +551,16 @@ def load_model(self) -> None: device_ids=None if self.cpu else [self.device], ) + torch._dynamo.config.optimize_ddp = False + torch._dynamo.config.assume_static_by_default = False + torch._dynamo.config.automatic_dynamic_shapes = True + + if self.config["optim"].get("compiles"): + os.environ["TORCH_LOGS"] = "recompiles" + self.model = torch.compile(self.model, dynamic=True) + torch._dynamo.config.optimize_ddp = False + logging.info("torch compiled model") + @property def _unwrapped_model(self): module = self.model diff --git a/tests/core/models/atoms.json b/tests/core/models/atoms.json index 97c6c47304..d378cc6bdd 100644 --- a/tests/core/models/atoms.json +++ b/tests/core/models/atoms.json @@ -16,5 +16,23 @@ "tags": {"__ndarray__": [[34], "int64", [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]}, "unique_id": "77df5102462860280bfa6b622c880125", "user": "bwood"}, -"ids": [1], -"nextid": 2} + "2": { + "calculator": "unknown", + "calculator_parameters": {}, + "cell": {"array": {"__ndarray__": [[3, 3], "float64", [0.0, -8.07194878, 0.0, 6.93127032, 0.0, 0.08307657, 0.0, 0.0, 39.37850739]]}, "pbc": {"__ndarray__": [[3], "bool", [true, true, true]]}, "__ase_objtype__": "cell"}, + "constraints": [{"name": "FixAtoms", "kwargs": {"indices": [2, 3, 5, 6, 7, 9, 10, 11, 12, 14, 15, 17, 18, 19, 20, 22, 23, 24, 26, 27, 28, 30, 31, 33]}}], + "ctime": 20.460198850701047, + "energy": -135.66393572, + "forces": {"__ndarray__": [[35, 3], "float64", [0.01, 0.01, 0.01, 0.05011766, -0.01973735, 0.23846654, -0.12013861, -0.05240431, -0.22395961, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.10578597, 0.01361956, -0.05699137, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.03172177, 0.00066391, -0.01049754, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.00908246, -0.09729627, 0.00726873, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.02260358, -0.09508909, -0.01036104, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.03928853, -0.04423657, 0.04053315, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.02912151, 0.05899768, -0.01100117, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.09680946, 0.06950572, 0.05602877, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.03057741, 0.10594487, -0.04712197, 0.0, 0.0, 0.0]]}, + "initial_charges": {"__ndarray__": [[35], "float64", [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]}, + "initial_magmoms": {"__ndarray__": [[35], 
"float64", [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]}, + "momenta": {"__ndarray__": [[35, 3], "float64", [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]}, + "mtime": 20.460198850701047, + "numbers": {"__ndarray__": [[35], "int64", [6, 6, 8, 13, 13, 13, 13, 13, 13, 13, 13, 29, 29, 29, 29, 29, 29, 29, 29, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34]]}, + "pbc": {"__ndarray__": [[3], "bool", [true, true, true]]}, + "positions": {"__ndarray__": [[35, 3], "float64", [-0.5, -0.5, -0.5, -0.3289066593614256, -3.0340615866893037, 27.073342845551938, -0.0750331499077992, -2.8712314914365584, 28.205836912191387, 6.2092629718957655, -4.771209055418616, 21.953210855443853, 3.8988395550000003, -0.735234665418617, 18.643976120697392, 1.636610785518665, -1.2302542698255066, 23.72823397486728, 1.5884161381042343, -4.771209055418616, 15.334741779736007, 2.7436278118957658, -6.789196250418616, 21.91167257044385, 0.433204395, -2.7532218604186167, 18.602437835697394, 5.33707967127947, -3.0430981333485136, 25.502246117362063, 5.054051298104235, -6.789196250418616, 15.376280064736006, 3.8988395550000003, -4.771209055418616, 18.643976120697392, 1.5884161381042343, -0.735234665418617, 15.334741779736007, 6.2092629718957655, -0.735234665418617, 21.953210855443853, 1.7024669335227842, -4.898430878701221, 24.462466125364735, 2.7436278118957658, -2.7532218604186167, 21.91167257044385, 0.433204395, -6.789196250418616, 18.602437835697394, 5.0596241087542175, -7.073912126493459, 24.329534869886448, 5.054051298104235, -2.7532218604186167, 15.376280064736006, 1.5841717747237825, -4.763794809025211, 17.789819163977032, 6.205018677828017, -0.7278204190252113, 14.563661393015645, 3.8945952609322516, -0.7278204190252113, 21.09905389955426, 6.2730609484910635, -5.008717107687484, 24.37936591790035, 5.049806934723782, -6.796610416092535, 17.831357448977034, 2.739383517828017, -2.7606360260925347, 14.522123108015645, 0.4289601009322512, -2.7606360260925347, 21.05751561455426, 2.7016609108638554, -7.122213699359126, 24.33216256212159, 5.058295592171984, -2.7458076140252117, 17.84351570962914, 2.747872175276218, -6.781782004025211, 14.534281368667754, 0.43744868906774886, -6.781782004025211, 21.069673874375603, 3.0271987649116516, -2.983072135599385, 24.66107410517354, 3.903083849067749, -4.778623221092535, 21.111212159375604, 1.5926604321719833, -0.7426488310925348, 17.801977424629143, 6.319541839318875, -0.99856463967624, 24.661108015400288, 6.213507335276218, -4.778623221092535, 14.575819653667754]]}, + "tags": {"__ndarray__": [[35], "int64", [0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]}, + "unique_id": "78df5102462860280bfa6b622c880125", + "user": "bwood"}, +"ids": [1, 2], +"nextid": 3} diff --git a/tests/core/models/test_eqv2_compiles.py b/tests/core/models/test_eqv2_compiles.py new file mode 100644 index 0000000000..e27c757c73 --- /dev/null +++ 
b/tests/core/models/test_eqv2_compiles.py @@ -0,0 +1,129 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +from __future__ import annotations + +import copy +import io +import os + +import pytest +import requests +import torch +from ase.io import read +from torch.nn.parallel.distributed import DistributedDataParallel + +from fairchem.core.common.registry import registry +from fairchem.core.common.distutils import init_local_distributed_process_group +from fairchem.core.common.utils import load_state_dict, setup_imports +from fairchem.core.datasets import data_list_collater +from fairchem.core.preprocessing import AtomsToGraphs + + +def load_data(): + atoms = read( + os.path.join(os.path.dirname(os.path.abspath(__file__)), "atoms.json"), + index=":", + format="json", + ) + a2g = AtomsToGraphs( + max_neigh=200, + radius=6, + r_edges=False, + r_fixed=True, + ) + data_list = a2g.convert_all(atoms) + return data_list + + +def load_model(): + torch.manual_seed(4) + setup_imports() + model = registry.get_model_class("equiformer_v2")( + use_pbc=True, + regress_forces=True, + otf_graph=True, + max_neighbors=20, + max_radius=12.0, + max_num_elements=90, + num_layers=8, + sphere_channels=128, + attn_hidden_channels=64, + num_heads=8, + attn_alpha_channels=64, + attn_value_channels=16, + ffn_hidden_channels=128, + norm_type="layer_norm_sh", + lmax_list=[4], + mmax_list=[2], + grid_resolution=18, + num_sphere_samples=128, + edge_channels=128, + use_atom_edge_embedding=True, + distance_function="gaussian", + num_distance_basis=512, + attn_activation="silu", + use_s2_act_attn=False, + ffn_activation="silu", + use_gate_act=False, + use_grid_mlp=True, + alpha_drop=0.1, + drop_path_rate=0.1, + proj_drop=0.0, + weight_init="uniform", + ) + return model + + +def init(backend="nccl"): + if not torch.distributed.is_initialized(): + init_local_distributed_process_group(backend=backend) + + +def expected_energy_forces(): + energy = torch.tensor([-0.0261]) + forces = torch.tensor([-0.0008, -0.0018, -0.0020]) + return energy, forces + + +class TestEQV2Compiles: + def eqv2_baseline_output(self, backend: str): + init(backend=backend) + data = load_data() + data = data_list_collater([data[0]])#.to("cuda") + model = load_model()#.cuda() + ddp_model = DistributedDataParallel(model) + return ddp_model(data) + + def test_baseline_cpu(self): + outputs = self.eqv2_baseline_output("gloo") + energy, forces_mean = outputs["energy"].detach().cpu(), outputs["forces"].mean(0).detach().cpu() + expected_energy, expected_forces = expected_energy_forces() + print(energy) + print(forces_mean) + assert torch.allclose(energy, expected_energy, atol=1e-4) + assert torch.allclose(forces_mean, expected_forces, atol=1e-4) + + def test_eqv2_compiles(self): + init() + data = load_data() + data0 = data_list_collater([data[0]]).to("cuda") + data1 = data_list_collater([data[1]]).to("cuda") + model = load_model().cuda() + ddp_model = DistributedDataParallel(model) + + torch._dynamo.config.optimize_ddp = False + torch._dynamo.config.assume_static_by_default = False + torch._dynamo.config.automatic_dynamic_shapes = True + # torch._dynamo.config.suppress_errors = True + + os.environ["TORCH_LOGS"] = "+dynamo,recompiles" + compiled_model = torch.compile(model, dynamic=True) + torch._dynamo.config.optimize_ddp = False + compiled_model(data0) + compiled_model(data1) + # import pdb; pdb.set_trace() diff --git 
a/tests/core/models/test_escn_compiles.py b/tests/core/models/test_escn_compiles.py new file mode 100644 index 0000000000..0f0c9fb323 --- /dev/null +++ b/tests/core/models/test_escn_compiles.py @@ -0,0 +1,96 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +from __future__ import annotations + +import copy +import io +import os + +import pytest +import requests +import torch +from ase.io import read +from torch.nn.parallel.distributed import DistributedDataParallel + +from fairchem.core.common.registry import registry +from fairchem.core.common.distutils import init_local_distributed_process_group +from fairchem.core.common.utils import load_state_dict, setup_imports +from fairchem.core.datasets import data_list_collater +from fairchem.core.preprocessing import AtomsToGraphs + + +def load_data(): + atoms = read( + os.path.join(os.path.dirname(os.path.abspath(__file__)), "atoms.json"), + index=0, + format="json", + ) + a2g = AtomsToGraphs( + max_neigh=200, + radius=6, + r_edges=False, + r_fixed=True, + ) + data_list = a2g.convert_all([atoms]) + return data_list[0] + + +def load_model(): + torch.manual_seed(4) + setup_imports() + model = registry.get_model_class("escn")( + use_pbc = True, + use_pbc_single = False, + regress_forces = True, + otf_graph = True, + max_neighbors = 20, + cutoff = 8.0, + max_num_elements = 90, + num_layers = 8, + lmax_list = [4], + mmax_list = [2], + sphere_channels = 128, + hidden_channels = 256, + edge_channels = 128, + num_sphere_samples = 128, + distance_function = "gaussian", + basis_width_scalar = 1.0, + distance_resolution = 0.02, + show_timing_info = False, + resolution = None, + ) + return model + +@pytest.fixture(scope="session") +def init(): + init_local_distributed_process_group() + + +class TestESCNCompiles: + def escn_baseline(self, init): + data = load_data() + data = data_list_collater([data]).to("cuda") + model = load_model().cuda() + ddp_model = DistributedDataParallel(model) + return ddp_model(data) + + def test_escn_compiles(self, init): + data = load_data() + data = data_list_collater([data]).to("cuda") + model = load_model().cuda() + ddp_model = DistributedDataParallel(model) + + torch._dynamo.config.optimize_ddp = False + torch._dynamo.config.assume_static_by_default = False + torch._dynamo.config.automatic_dynamic_shapes = True + # torch._dynamo.config.suppress_errors = True + + os.environ["TORCH_LOGS"] = "+dynamo,recompiles" + compiled_model = torch.compile(model, dynamic=True) + torch._dynamo.config.optimize_ddp = False + compiled_model(data) From 211763288f732e5e44522c02261015f7a6765715 Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Sat, 24 Aug 2024 12:52:25 -0700 Subject: [PATCH 02/25] update --- tests/core/models/test_eqv2_compiles.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/core/models/test_eqv2_compiles.py b/tests/core/models/test_eqv2_compiles.py index e27c757c73..3d87c71196 100644 --- a/tests/core/models/test_eqv2_compiles.py +++ b/tests/core/models/test_eqv2_compiles.py @@ -103,8 +103,6 @@ def test_baseline_cpu(self): outputs = self.eqv2_baseline_output("gloo") energy, forces_mean = outputs["energy"].detach().cpu(), outputs["forces"].mean(0).detach().cpu() expected_energy, expected_forces = expected_energy_forces() - print(energy) - print(forces_mean) assert torch.allclose(energy, expected_energy, atol=1e-4) assert torch.allclose(forces_mean, expected_forces, 
atol=1e-4) @@ -119,6 +117,7 @@ def test_eqv2_compiles(self): torch._dynamo.config.optimize_ddp = False torch._dynamo.config.assume_static_by_default = False torch._dynamo.config.automatic_dynamic_shapes = True + torch._dynamo.config.cache_size_limit = 1 # torch._dynamo.config.suppress_errors = True os.environ["TORCH_LOGS"] = "+dynamo,recompiles" From 873f6d18281a03818c66cc9faf93d12b65bab611 Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Sun, 25 Aug 2024 15:52:37 -0700 Subject: [PATCH 03/25] export so2 --- .../core/models/equiformer_v2/so2_ops.py | 140 ++++++++++++++++++ src/fairchem/core/models/equiformer_v2/so3.py | 6 +- tests/core/models/test_eqv2_compiles.py | 66 +++++++++ 3 files changed, 209 insertions(+), 3 deletions(-) diff --git a/src/fairchem/core/models/equiformer_v2/so2_ops.py b/src/fairchem/core/models/equiformer_v2/so2_ops.py index 71666284e8..1848fae2cb 100644 --- a/src/fairchem/core/models/equiformer_v2/so2_ops.py +++ b/src/fairchem/core/models/equiformer_v2/so2_ops.py @@ -139,6 +139,7 @@ def __init__( self.rad_func = RadialFunction(self.edge_channels_list) def forward(self, x, x_edge): + num_edges = len(x_edge) out = [] @@ -352,3 +353,142 @@ def forward(self, x, x_edge): out_embedding._l_primary(self.mappingReduced) return out_embedding + +class SO2_Convolution_Exportable(torch.nn.Module): + """ + SO(2) Block: Perform SO(2) convolutions for all m (orders) + + Args: + sphere_channels (int): Number of spherical channels + m_output_channels (int): Number of output channels used during the SO(2) conv + lmax_list (list:int): List of degrees (l) for each resolution + mmax_list (list:int): List of orders (m) for each resolution + mappingReduced (CoefficientMappingModule): Used to extract a subset of m components + internal_weights (bool): If True, not using radial function to multiply inputs features + edge_channels_list (list:int): List of sizes of invariant edge embedding. For example, [input_channels, hidden_channels, hidden_channels]. + extra_m0_output_channels (int): If not None, return `out_embedding` (SO3_Embedding) and `extra_m0_features` (Tensor). 
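+
+    Note: unlike `SO2_Convolution`, `forward` here takes and returns plain
+    tensors (`x_emb` of shape [num_edges, num_sh_coefs, num_features]) rather
+    than an `SO3_Embedding`, which keeps the module traceable by `torch.export`.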
+ """ + + def __init__( + self, + sphere_channels: int, + m_output_channels: int, + lmax_list: list[int], + mmax_list: list[int], + mappingReduced, + internal_weights: bool = True, + edge_channels_list: list[int] | None = None, + extra_m0_output_channels: int | None = None, + ): + super().__init__() + self.sphere_channels = sphere_channels + self.m_output_channels = m_output_channels + self.lmax_list = lmax_list + self.mmax_list = mmax_list + self.mappingReduced = mappingReduced + self.num_resolutions = len(lmax_list) + self.internal_weights = internal_weights + self.edge_channels_list = copy.deepcopy(edge_channels_list) + self.extra_m0_output_channels = extra_m0_output_channels + + num_channels_rad = 0 # for radial function + + num_channels_m0 = 0 + for i in range(self.num_resolutions): + num_coefficients = self.lmax_list[i] + 1 + num_channels_m0 = num_channels_m0 + num_coefficients * self.sphere_channels + + # SO(2) convolution for m = 0 + m0_output_channels = self.m_output_channels * ( + num_channels_m0 // self.sphere_channels + ) + if self.extra_m0_output_channels is not None: + m0_output_channels = m0_output_channels + self.extra_m0_output_channels + self.fc_m0 = Linear(num_channels_m0, m0_output_channels) + num_channels_rad = num_channels_rad + self.fc_m0.in_features + + # SO(2) convolution for non-zero m + self.so2_m_conv = nn.ModuleList() + for m in range(1, max(self.mmax_list) + 1): + self.so2_m_conv.append( + SO2_m_Convolution( + m, + self.sphere_channels, + self.m_output_channels, + self.lmax_list, + self.mmax_list, + ) + ) + num_channels_rad = num_channels_rad + self.so2_m_conv[-1].fc.in_features + + # Embedding function of distance + self.rad_func = None + if not self.internal_weights: + assert self.edge_channels_list is not None + self.edge_channels_list.append(int(num_channels_rad)) + self.rad_func = RadialFunction(self.edge_channels_list) + + def forward(self, x_emb, x_edge): + # x_emb: [num_edges, num_sh_coefs, num_features] + # x_edge: [num_edges, num_edge_features] + + num_edges = x_edge.shape[0] + out = [] + # torch export does not inputs based on a buffered tensor + m_size = self.mappingReduced.m_size + + # Reshape the spherical harmonics based on m (order), equivalent to x._m_primary + x_emb = torch.einsum("nac, ba -> nbc", x_emb, self.mappingReduced.to_m) + + # radial function + if self.rad_func is not None: + x_edge = self.rad_func(x_edge) + offset_rad = 0 + + # Compute m=0 coefficients separately since they only have real values (no imaginary) + x_0 = x_emb.narrow(1, 0, m_size[0]) + x_0 = x_0.reshape(x_edge.shape[0], -1) + if self.rad_func is not None: + x_edge_0 = x_edge.narrow(1, 0, self.fc_m0.in_features) + x_0 = x_0 * x_edge_0 + x_0 = self.fc_m0(x_0) + + x_0_extra = None + # extract extra m0 features + if self.extra_m0_output_channels is not None: + x_0_extra = x_0.narrow(-1, 0, self.extra_m0_output_channels) + x_0 = x_0.narrow( + -1, + self.extra_m0_output_channels, + (self.fc_m0.out_features - self.extra_m0_output_channels), + ) + + x_0 = x_0.view(num_edges, -1, self.m_output_channels) + out.append(x_0) + offset_rad = offset_rad + self.fc_m0.in_features + + # Compute the values for the m > 0 coefficients + offset = m_size[0] + for m in range(1, max(self.mmax_list) + 1): + # Get the m order coefficients + x_m = x_emb.narrow(1, offset, 2 * m_size[m]) + x_m = x_m.reshape(num_edges, 2, -1) + + # Perform SO(2) convolution + if self.rad_func is not None: + x_edge_m = x_edge.narrow( + 1, offset_rad, self.so2_m_conv[m - 1].fc.in_features + ) + x_edge_m = 
x_edge_m.reshape( + num_edges, 1, self.so2_m_conv[m - 1].fc.in_features + ) + x_m = x_m * x_edge_m + x_m = self.so2_m_conv[m - 1](x_m) + x_m = x_m.view(num_edges, -1, self.m_output_channels) + out.append(x_m) + offset = offset + 2 * m_size[m] + offset_rad = offset_rad + self.so2_m_conv[m - 1].fc.in_features + + out = torch.cat(out, dim=1) + out = torch.einsum("nac, ab -> nbc", out, self.mappingReduced.to_m) + return out diff --git a/src/fairchem/core/models/equiformer_v2/so3.py b/src/fairchem/core/models/equiformer_v2/so3.py index 909290e319..61c25f9437 100644 --- a/src/fairchem/core/models/equiformer_v2/so3.py +++ b/src/fairchem/core/models/equiformer_v2/so3.py @@ -78,7 +78,7 @@ def __init__( num_coefficients = len(l_harmonic) # `self.to_m` moves m components from different L to contiguous index to_m = torch.zeros([num_coefficients, num_coefficients]) - m_size = torch.zeros([max(self.mmax_list) + 1]).long() + self.m_size = torch.zeros([max(self.mmax_list) + 1]).long().tolist() offset = 0 for m in range(max(self.mmax_list) + 1): @@ -88,7 +88,7 @@ def __init__( to_m[idx_out + offset, idx_in] = 1.0 offset = offset + len(idx_r) - m_size[m] = int(len(idx_r)) + self.m_size[m] = int(len(idx_r)) for idx_out, idx_in in enumerate(idx_i): to_m[idx_out + offset, idx_in] = 1.0 @@ -102,7 +102,7 @@ def __init__( self.register_buffer('m_complex', m_complex) self.register_buffer('res_size', res_size) self.register_buffer('to_m', to_m) - self.register_buffer('m_size', m_size) + # self.register_buffer('m_size', m_size) self.pre_compute_coefficient_idx() if self.use_rotate_inv_rescale: diff --git a/tests/core/models/test_eqv2_compiles.py b/tests/core/models/test_eqv2_compiles.py index 3d87c71196..d4e3390980 100644 --- a/tests/core/models/test_eqv2_compiles.py +++ b/tests/core/models/test_eqv2_compiles.py @@ -22,7 +22,13 @@ from fairchem.core.common.utils import load_state_dict, setup_imports from fairchem.core.datasets import data_list_collater from fairchem.core.preprocessing import AtomsToGraphs +from torch_geometric.data import Data, Batch +from fairchem.core.models.equiformer_v2.so3 import CoefficientMappingModule, SO3_Embedding +from fairchem.core.models.equiformer_v2.so2_ops import SO2_Convolution, SO2_Convolution_Exportable + +from torch.export import export +from torch.export import Dim def load_data(): atoms = read( @@ -90,6 +96,15 @@ def expected_energy_forces(): return energy, forces +def rand_input(natoms: int) -> BaseData: + data = Data(natoms=natoms, + pos=torch.rand(natoms, 3), + cell=torch.rand([1, 3, 3]), + atomic_numbers=torch.randint(1, 99, (1, 3, 3)) + ) + batch = Batch.from_data_list([data]) + return batch + class TestEQV2Compiles: def eqv2_baseline_output(self, backend: str): init(backend=backend) @@ -126,3 +141,54 @@ def test_eqv2_compiles(self): compiled_model(data0) compiled_model(data1) # import pdb; pdb.set_trace() + +class TestExportableEQV2: + def test_so2_conv_equivalent(self): + torch.manual_seed(4) + lmax, mmax = 4, 2 + sc, mc = 128, 128 + mappingReduced = CoefficientMappingModule([lmax], [mmax]) + + start_rng_state = torch.random.get_rng_state() + so2_export = SO2_Convolution_Exportable(sphere_channels=sc, m_output_channels=mc, lmax_list=[lmax], mmax_list=[mmax],mappingReduced=mappingReduced) + torch.random.set_rng_state(start_rng_state) + so2 = SO2_Convolution(sphere_channels=sc, m_output_channels=mc, lmax_list=[lmax], mmax_list=[mmax],mappingReduced=mappingReduced) + + inputs_tensor = (torch.rand(129, 19, 128), torch.rand(129, 856)) + inputs_embedding = SO3_Embedding(129, 
[lmax], 128, inputs_tensor[0].device, inputs_tensor[0].dtype) + inputs_embedding.set_embedding(inputs_tensor[0]) + assert torch.allclose(inputs_tensor[0], inputs_embedding.embedding) + output = so2(inputs_embedding, inputs_tensor[1]) + output_export = so2_export(*inputs_tensor) + assert torch.allclose(output.embedding, output_export) + + def test_so2_conv_exportable(self): + torch._dynamo.config.assume_static_by_default = False + torch._dynamo.config.automatic_dynamic_shapes = True + inp1_dim0 = Dim("inp1_dim0") + inp1_dim1 = None + inp1_dim2 = None + inp2_dim0 = inp1_dim0 + inp2_dim1 = Dim("inp2_dim1") + + dynamic_shapes1 = { + "x_emb": {0: inp1_dim0, 1: inp1_dim1, 2: inp1_dim2}, + "x_edge": {0: inp2_dim0, 1: inp2_dim1}, + } + + lmax, mmax = 4, 2 + mappingReduced = CoefficientMappingModule([lmax], [mmax]) + args=(torch.rand(129, 19, 128), torch.rand(129, 856)) + + so2 = SO2_Convolution_Exportable(sphere_channels=128, m_output_channels=128, lmax_list=[lmax], mmax_list=[mmax],mappingReduced=mappingReduced) + prog = export(so2, args=args, dynamic_shapes=dynamic_shapes1) + export_out = prog.module()(*args) + regular_out = so2(*args) + assert torch.allclose(export_out, regular_out) + + args2=(torch.rand(130, 19, 128), torch.rand(130, 856)) + export_out2 = prog.module()(*args2) + regular_out2 = so2(*args2) + assert torch.allclose(export_out2, regular_out2) + + From cdb44100017ea74639ffd34611c2ebdf6e7745c3 Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Wed, 28 Aug 2024 21:56:12 -0700 Subject: [PATCH 04/25] move mappingReduced to member --- src/fairchem/core/models/escn/escn.py | 59 ++++++++------- src/fairchem/core/models/escn/so3.py | 8 +- tests/core/models/test_escn_compiles.py | 97 ++++++++++++++++++++++--- 3 files changed, 119 insertions(+), 45 deletions(-) diff --git a/src/fairchem/core/models/escn/escn.py b/src/fairchem/core/models/escn/escn.py index 6eb95947ae..0a6c7064d5 100644 --- a/src/fairchem/core/models/escn/escn.py +++ b/src/fairchem/core/models/escn/escn.py @@ -21,12 +21,12 @@ from fairchem.core.common.registry import registry from fairchem.core.common.utils import conditional_grad from fairchem.core.models.base import BackboneInterface, GraphModelMixin, HeadInterface -from fairchem.core.models.escn.so3 import ( +from fairchem.core.models.utils.so3_utils import ( CoefficientMapping, - SO3_Embedding, SO3_Grid, SO3_Rotation, ) +from fairchem.core.models.escn.so3 import SO3_Embedding from fairchem.core.models.scn.sampling import CalcSpherePoints from fairchem.core.models.scn.smearing import ( GaussianSmearing, @@ -181,6 +181,8 @@ def __init__( self.SO3_grid.append(SO3_m_grid) + self.mappingReduced = CoefficientMapping(self.lmax_list, self.mmax_list) + # Initialize the blocks for each layer of the GNN self.layer_blocks = nn.ModuleList() for i in range(self.num_layers): @@ -195,6 +197,7 @@ def __init__( self.max_num_elements, self.SO3_grid, self.act, + self.mappingReduced ) self.layer_blocks.append(block) @@ -227,6 +230,7 @@ def __init__( ) self.sphharm_weights = nn.ParameterList(sphharm_weights) + @conditional_grad(torch.enable_grad()) def forward(self, data): device = data.pos.device @@ -279,9 +283,6 @@ def forward(self, data): offset = offset + self.sphere_channels offset_res = offset_res + int((self.lmax_list[i] + 1) ** 2) - # This can be expensive to compute (not implemented efficiently), so only do it once and pass it along to each layer - mappingReduced = CoefficientMapping(self.lmax_list, self.mmax_list, device) - ############################################################### # 
Update spherical node embeddings ############################################################### @@ -294,7 +295,6 @@ def forward(self, data): graph.edge_distance, graph.edge_index, self.SO3_edge_rot, - mappingReduced, ) # Residual layer for all layers past the first @@ -308,7 +308,6 @@ def forward(self, data): graph.edge_distance, graph.edge_index, self.SO3_edge_rot, - mappingReduced, ) # Sample the spherical channels (node embeddings) at evenly distributed points on the sphere. @@ -473,9 +472,6 @@ def forward(self, data: Batch) -> dict[str, torch.Tensor]: offset = offset + self.sphere_channels offset_res = offset_res + int((self.lmax_list[i] + 1) ** 2) - # This can be expensive to compute (not implemented efficiently), so only do it once and pass it along to each layer - mappingReduced = CoefficientMapping(self.lmax_list, self.mmax_list, device) - ############################################################### # Update spherical node embeddings ############################################################### @@ -488,7 +484,6 @@ def forward(self, data: Batch) -> dict[str, torch.Tensor]: graph.edge_distance, graph.edge_index, self.SO3_edge_rot, - mappingReduced, ) # Residual layer for all layers past the first @@ -502,7 +497,6 @@ def forward(self, data: Batch) -> dict[str, torch.Tensor]: graph.edge_distance, graph.edge_index, self.SO3_edge_rot, - mappingReduced, ) # Sample the spherical channels (node embeddings) at evenly distributed points on the sphere. @@ -599,6 +593,7 @@ def __init__( max_num_elements: int, SO3_grid: SO3_Grid, act, + mappingReduced, ) -> None: super().__init__() self.layer_idx = layer_idx @@ -609,6 +604,7 @@ def __init__( self.sphere_channels = sphere_channels self.sphere_channels_all = self.num_resolutions * self.sphere_channels self.SO3_grid = SO3_grid + self.mappingReduced = mappingReduced # Message block self.message_block = MessageBlock( @@ -622,6 +618,7 @@ def __init__( max_num_elements, self.SO3_grid, self.act, + self.mappingReduced ) # Non-linear point-wise comvolution for the aggregated messages @@ -644,7 +641,6 @@ def forward( edge_distance, edge_index, SO3_edge_rot, - mappingReduced, ): # Compute messages by performing message block x_message = self.message_block( @@ -653,7 +649,6 @@ def forward( edge_distance, edge_index, SO3_edge_rot, - mappingReduced, ) # Compute point-wise spherical non-linearity on aggregated messages @@ -705,6 +700,7 @@ def __init__( max_num_elements: int, SO3_grid: SO3_Grid, act, + mappingReduced, ) -> None: super().__init__() self.layer_idx = layer_idx @@ -716,6 +712,7 @@ def __init__( self.lmax_list = lmax_list self.mmax_list = mmax_list self.edge_channels = edge_channels + self.mappingReduced = mappingReduced # Create edge scalar (invariant to rotations) features self.edge_block = EdgeBlock( @@ -733,6 +730,7 @@ def __init__( self.lmax_list, self.mmax_list, self.act, + self.mappingReduced ) self.so2_block_target = SO2Block( self.sphere_channels, @@ -741,6 +739,7 @@ def __init__( self.lmax_list, self.mmax_list, self.act, + self.mappingReduced ) def forward( @@ -750,7 +749,6 @@ def forward( edge_distance, edge_index, SO3_edge_rot, - mappingReduced, ): ############################################################### # Compute messages @@ -775,17 +773,17 @@ def forward( x_target._rotate(SO3_edge_rot, self.lmax_list, self.mmax_list) # Compute messages - x_source = self.so2_block_source(x_source, x_edge, mappingReduced) - x_target = self.so2_block_target(x_target, x_edge, mappingReduced) + x_source = self.so2_block_source(x_source, x_edge) 
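+        # `mappingReduced` is now a member of each SO2Block (set in __init__),
+        # so it is no longer threaded through every forward call.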
+ x_target = self.so2_block_target(x_target, x_edge) # Add together the source and target results x_target.embedding = x_source.embedding + x_target.embedding # Point-wise spherical non-linearity - x_target._grid_act(self.SO3_grid, self.act, mappingReduced) + x_target._grid_act(self.SO3_grid, self.act, self.mappingReduced.res_size) # Rotate back the irreps - x_target._rotate_inv(SO3_edge_rot, mappingReduced) + x_target._rotate_inv(SO3_edge_rot, self.mappingReduced.res_size) # Compute the sum of the incoming neighboring messages for each target node x_target._reduce_edge(edge_index[1], len(x.embedding)) @@ -814,6 +812,7 @@ def __init__( lmax_list: list[int], mmax_list: list[int], act, + mappingReduced ) -> None: super().__init__() self.sphere_channels = sphere_channels @@ -822,6 +821,7 @@ def __init__( self.mmax_list = mmax_list self.num_resolutions: int = len(lmax_list) self.act = act + self.mappingReduced = mappingReduced num_channels_m0 = 0 for i in range(self.num_resolutions): @@ -849,21 +849,20 @@ def __init__( def forward( self, - x, - x_edge, - mappingReduced, + x: torch.Tensor, + x_edge: torch.Tensor, ): num_edges = len(x_edge) # Reshape the spherical harmonics based on m (order) - x._m_primary(mappingReduced) + x._m_primary(self.mappingReduced) # Compute m=0 coefficients separately since they only have real values (no imaginary) # Compute edge scalar features for m=0 x_edge_0 = self.act(self.fc1_dist0(x_edge)) - x_0 = x.embedding[:, 0 : mappingReduced.m_size[0]].contiguous() + x_0 = x.embedding[:, 0 : self.mappingReduced.m_size[0]].contiguous() x_0 = x_0.view(num_edges, -1) x_0 = self.fc1_m0(x_0) @@ -872,25 +871,25 @@ def forward( x_0 = x_0.view(num_edges, -1, x.num_channels) # Update the m=0 coefficients - x.embedding[:, 0 : mappingReduced.m_size[0]] = x_0 + x.embedding[:, 0 : self.mappingReduced.m_size[0]] = x_0 # Compute the values for the m > 0 coefficients - offset = mappingReduced.m_size[0] + offset = self.mappingReduced.m_size[0] for m in range(1, max(self.mmax_list) + 1): # Get the m order coefficients x_m = x.embedding[ - :, offset : offset + 2 * mappingReduced.m_size[m] + :, offset : offset + 2 * self.mappingReduced.m_size[m] ].contiguous() x_m = x_m.view(num_edges, 2, -1) # Perform SO(2) convolution x_m = self.so2_conv[m - 1](x_m, x_edge) x_m = x_m.view(num_edges, -1, x.num_channels) - x.embedding[:, offset : offset + 2 * mappingReduced.m_size[m]] = x_m + x.embedding[:, offset : offset + 2 * self.mappingReduced.m_size[m]] = x_m - offset = offset + 2 * mappingReduced.m_size[m] + offset = offset + 2 * self.mappingReduced.m_size[m] # Reshape the spherical harmonics based on l (degree) - x._l_primary(mappingReduced) + x._l_primary(self.mappingReduced) return x diff --git a/src/fairchem/core/models/escn/so3.py b/src/fairchem/core/models/escn/so3.py index fec2b85a3c..cb0f9d223b 100644 --- a/src/fairchem/core/models/escn/so3.py +++ b/src/fairchem/core/models/escn/so3.py @@ -241,12 +241,12 @@ def _rotate(self, SO3_rotation, lmax_list, mmax_list) -> None: self.set_lmax_mmax(lmax_list.copy(), mmax_list.copy()) # Rotate the embedding by the inverse of the rotation matrix - def _rotate_inv(self, SO3_rotation, mappingReduced) -> None: + def _rotate_inv(self, SO3_rotation, res_size) -> None: embedding_rotate = torch.tensor([], device=self.device, dtype=self.dtype) offset = 0 for i in range(self.num_resolutions): - num_coefficients = mappingReduced.res_size[i] + num_coefficients = res_size[i] embedding_i = self.embedding[:, offset : offset + num_coefficients] embedding_rotate = 
torch.cat( [ @@ -268,10 +268,10 @@ def _rotate_inv(self, SO3_rotation, mappingReduced) -> None: self.set_lmax_mmax(self.lmax_list, self.mmax_list) # Compute point-wise spherical non-linearity - def _grid_act(self, SO3_grid, act, mappingReduced) -> None: + def _grid_act(self, SO3_grid, act, res_size) -> None: offset = 0 for i in range(self.num_resolutions): - num_coefficients = mappingReduced.res_size[i] + num_coefficients = res_size[i] x_res = self.embedding[:, offset : offset + num_coefficients].contiguous() to_grid_mat = SO3_grid[self.lmax_list[i]][ diff --git a/tests/core/models/test_escn_compiles.py b/tests/core/models/test_escn_compiles.py index 0f0c9fb323..89ba8502ad 100644 --- a/tests/core/models/test_escn_compiles.py +++ b/tests/core/models/test_escn_compiles.py @@ -10,6 +10,9 @@ import copy import io import os +import random +import numpy as np +import logging import pytest import requests @@ -22,7 +25,13 @@ from fairchem.core.common.utils import load_state_dict, setup_imports from fairchem.core.datasets import data_list_collater from fairchem.core.preprocessing import AtomsToGraphs +from fairchem.core.common.transforms import RandomRotate +from fairchem.core.models.utils.so3_utils import CoefficientMapping +from fairchem.core.models.escn.escn import SO2Block + +from torch.export import export +from torch.export import Dim def load_data(): atoms = read( @@ -66,23 +75,33 @@ def load_model(): ) return model -@pytest.fixture(scope="session") -def init(): - init_local_distributed_process_group() +def expected_energy_forces(): + energy = torch.tensor([0.0001747000]) + forces = torch.tensor([-1.2720219900e-07, 8.2126695133e-07, -3.8776403244e-07]) + return energy, forces +def init(backend="nccl"): + if not torch.distributed.is_initialized(): + init_local_distributed_process_group(backend=backend) class TestESCNCompiles: - def escn_baseline(self, init): + def test_escn_baseline_cpu(self): + init("gloo") data = load_data() - data = data_list_collater([data]).to("cuda") - model = load_model().cuda() + data = data_list_collater([data]) + model = load_model() ddp_model = DistributedDataParallel(model) - return ddp_model(data) + output = ddp_model(data) + expected_energy, expected_forces = expected_energy_forces() + torch.set_printoptions(precision=8) + assert torch.allclose(output["energy"], expected_energy) + assert torch.allclose(output["forces"].mean(0), expected_forces) - def test_escn_compiles(self, init): + def test_escn_compiles(self): + init("gloo") data = load_data() - data = data_list_collater([data]).to("cuda") - model = load_model().cuda() + data = data_list_collater([data]) + model = load_model() ddp_model = DistributedDataParallel(model) torch._dynamo.config.optimize_ddp = False @@ -90,7 +109,63 @@ def test_escn_compiles(self, init): torch._dynamo.config.automatic_dynamic_shapes = True # torch._dynamo.config.suppress_errors = True - os.environ["TORCH_LOGS"] = "+dynamo,recompiles" + # os.environ["TORCH_LOGS"] = "+dynamo,recompiles" + # torch._logging.set_logs(dynamo = logging.INFO) + # os.environ["TORCHDYNAMO_VERBOSE"] = "1" + # os.environ["TORCHDYNAMO_REPRO_AFTER"]="dynamo" + # torch._dynamo.config.verbose = True compiled_model = torch.compile(model, dynamic=True) torch._dynamo.config.optimize_ddp = False + # torch._dynamo.explain(model)(data) + # assert False compiled_model(data) + + def test_rotation_invariance(self) -> None: + random.seed(1) + data = load_data() + + # Sampling a random rotation within [-180, 180] for all axes. 
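+        # Energies should be invariant under rotation and forces equivariant:
+        # the rotated sample should give the same energy, and the forces should
+        # match after applying `inv_rot`, which is what the asserts below check.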
+ transform = RandomRotate([-180, 180], [0, 1, 2]) + data_rotated, rot, inv_rot = transform(data.clone()) + assert not np.array_equal(data.pos, data_rotated.pos) + + # Pass it through the model. + batch = data_list_collater([data, data_rotated]) + model = load_model() + model.eval() + out = model(batch) + + # Compare predicted energies and forces (after inv-rotation). + energies = out["energy"].detach() + np.testing.assert_almost_equal(energies[0], energies[1], decimal=5) + + forces = out["forces"].detach() + logging.info(forces) + np.testing.assert_array_almost_equal( + forces[: forces.shape[0] // 2], + torch.matmul(forces[forces.shape[0] // 2 :], inv_rot), + decimal=5, + ) + + # def test_escn_so2_conv_compiles(self) -> None: + # torch._dynamo.config.assume_static_by_default = False + # torch._dynamo.config.automatic_dynamic_shapes = True + # inp1_dim0 = Dim("inp1_dim0") + # inp1_dim1 = None + # inp1_dim2 = None + # inp2_dim0 = inp1_dim0 + # inp2_dim1 = Dim("inp2_dim1") + + # dynamic_shapes1 = { + # "x_emb": {0: inp1_dim0, 1: inp1_dim1, 2: inp1_dim2}, + # "x_edge": {0: inp2_dim0, 1: inp2_dim1}, + # } + + # lmax, mmax = 4, 2 + # mappingReduced = CoefficientMapping([lmax], [mmax]) + # edge_channels = 128 + # args=(torch.rand(680, 19, 128), torch.rand(680, edge_channels)) + + # so2 = SO2Block(sphere_channels=128, hidden_channels=128, edge_channels=edge_channels, lmax_list=[lmax], mmax_list=[mmax], act=torch.nn.SiLU()) + # prog = export(so2, args=args, dynamic_shapes=dynamic_shapes1) + # export_out = prog.module()(*args) From e4a426a91cf2e7d9cf488282aeb2c88bd75bce2c Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Wed, 28 Aug 2024 23:36:04 -0700 Subject: [PATCH 05/25] compile works, guard failures still --- src/fairchem/core/models/escn/escn.py | 28 ++++++++++++++----------- src/fairchem/core/models/escn/so3.py | 21 +++++++------------ tests/core/models/test_escn_compiles.py | 5 ++++- 3 files changed, 27 insertions(+), 27 deletions(-) diff --git a/src/fairchem/core/models/escn/escn.py b/src/fairchem/core/models/escn/escn.py index 0a6c7064d5..bedbf12c5f 100644 --- a/src/fairchem/core/models/escn/escn.py +++ b/src/fairchem/core/models/escn/escn.py @@ -173,14 +173,18 @@ def __init__( ) # Initialize the transformations between spherical and grid representations - self.SO3_grid = nn.ModuleList() - for lval in range(max(self.lmax_list) + 1): - SO3_m_grid = nn.ModuleList() - for m in range(max(self.lmax_list) + 1): - SO3_m_grid.append(SO3_Grid(lval, m, resolution=resolution)) - - self.SO3_grid.append(SO3_m_grid) - + assert self.num_resolutions == 1, "Only one resolution is supported" + self.SO3_grid = nn.ModuleDict() + self.SO3_grid["lmax_lmax"] = SO3_Grid(self.lmax_list[0], self.lmax_list[0], resolution=resolution) + self.SO3_grid["lmax_mmax"] = SO3_Grid(self.lmax_list[0], self.mmax_list[0], resolution=resolution) + # self.SO3_grid = nn.ModuleList() + # for lval in range(max(self.lmax_list) + 1): + # SO3_m_grid = nn.ModuleList() + # for m in range(max(self.lmax_list) + 1): + # SO3_m_grid.append(SO3_Grid(lval, m, resolution=resolution)) + + # self.SO3_grid.append(SO3_m_grid) + # import pdb;pdb.set_trace() self.mappingReduced = CoefficientMapping(self.lmax_list, self.mmax_list) # Initialize the blocks for each layer of the GNN @@ -655,8 +659,8 @@ def forward( max_lmax = max(self.lmax_list) # Project to grid - x_grid_message = x_message.to_grid(self.SO3_grid, lmax=max_lmax) - x_grid = x.to_grid(self.SO3_grid, lmax=max_lmax) + x_grid_message = x_message.to_grid(self.SO3_grid["lmax_lmax"]) + x_grid = 
x.to_grid(self.SO3_grid["lmax_lmax"]) x_grid = torch.cat([x_grid, x_grid_message], dim=3) # Perform point-wise convolution @@ -665,7 +669,7 @@ def forward( x_grid = self.fc3_sphere(x_grid) # Project back to spherical harmonic coefficients - x_message._from_grid(x_grid, self.SO3_grid, lmax=max_lmax) + x_message._from_grid(x_grid, self.SO3_grid["lmax_lmax"]) # Return aggregated messages return x_message @@ -780,7 +784,7 @@ def forward( x_target.embedding = x_source.embedding + x_target.embedding # Point-wise spherical non-linearity - x_target._grid_act(self.SO3_grid, self.act, self.mappingReduced.res_size) + x_target._grid_act(self.SO3_grid["lmax_mmax"], self.act, self.mappingReduced.res_size) # Rotate back the irreps x_target._rotate_inv(SO3_edge_rot, self.mappingReduced.res_size) diff --git a/src/fairchem/core/models/escn/so3.py b/src/fairchem/core/models/escn/so3.py index cb0f9d223b..0c674e2888 100644 --- a/src/fairchem/core/models/escn/so3.py +++ b/src/fairchem/core/models/escn/so3.py @@ -274,12 +274,8 @@ def _grid_act(self, SO3_grid, act, res_size) -> None: num_coefficients = res_size[i] x_res = self.embedding[:, offset : offset + num_coefficients].contiguous() - to_grid_mat = SO3_grid[self.lmax_list[i]][ - self.mmax_list[i] - ].get_to_grid_mat(self.device) - from_grid_mat = SO3_grid[self.lmax_list[i]][ - self.mmax_list[i] - ].get_from_grid_mat(self.device) + to_grid_mat = SO3_grid.get_to_grid_mat(self.device) + from_grid_mat = SO3_grid.get_from_grid_mat(self.device) x_grid = torch.einsum("bai,zic->zbac", to_grid_mat, x_res) x_grid = act(x_grid) @@ -293,8 +289,8 @@ def to_grid(self, SO3_grid, lmax: int = -1) -> torch.Tensor: if lmax == -1: lmax = max(self.lmax_list) - to_grid_mat_lmax = SO3_grid[lmax][lmax].get_to_grid_mat(self.device) - grid_mapping = SO3_grid[lmax][lmax].mapping + to_grid_mat_lmax = SO3_grid.get_to_grid_mat(self.device) + grid_mapping = SO3_grid.mapping offset = 0 x_grid = torch.tensor([], device=self.device) @@ -316,12 +312,9 @@ def to_grid(self, SO3_grid, lmax: int = -1) -> torch.Tensor: return x_grid # Compute irreps from grid representation - def _from_grid(self, x_grid, SO3_grid, lmax: int = -1) -> None: - if lmax == -1: - lmax = max(self.lmax_list) - - from_grid_mat_lmax = SO3_grid[lmax][lmax].get_from_grid_mat(self.device) - grid_mapping = SO3_grid[lmax][lmax].mapping + def _from_grid(self, x_grid, SO3_grid) -> None: + from_grid_mat_lmax = SO3_grid.get_from_grid_mat(self.device) + grid_mapping = SO3_grid.mapping offset = 0 offset_channel = 0 diff --git a/tests/core/models/test_escn_compiles.py b/tests/core/models/test_escn_compiles.py index 89ba8502ad..e6d0e0d8f4 100644 --- a/tests/core/models/test_escn_compiles.py +++ b/tests/core/models/test_escn_compiles.py @@ -118,7 +118,10 @@ def test_escn_compiles(self): torch._dynamo.config.optimize_ddp = False # torch._dynamo.explain(model)(data) # assert False - compiled_model(data) + output = compiled_model(data) + expected_energy, expected_forces = expected_energy_forces() + assert torch.allclose(output["energy"], expected_energy) + assert torch.allclose(output["forces"].mean(0), expected_forces) def test_rotation_invariance(self) -> None: random.seed(1) From 672e1fff547670724410a2effb2b5265263636d4 Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Thu, 29 Aug 2024 10:45:26 -0700 Subject: [PATCH 06/25] escn so2 exports --- src/fairchem/core/models/escn/escn.py | 21 ++++----- tests/core/models/test_escn_compiles.py | 59 ++++++++++++++----------- 2 files changed, 44 insertions(+), 36 deletions(-) diff --git 
a/src/fairchem/core/models/escn/escn.py b/src/fairchem/core/models/escn/escn.py index bedbf12c5f..7b05278f21 100644 --- a/src/fairchem/core/models/escn/escn.py +++ b/src/fairchem/core/models/escn/escn.py @@ -777,8 +777,8 @@ def forward( x_target._rotate(SO3_edge_rot, self.lmax_list, self.mmax_list) # Compute messages - x_source = self.so2_block_source(x_source, x_edge) - x_target = self.so2_block_target(x_target, x_edge) + x_source.embedding = self.so2_block_source(x_source.embedding, x_edge) + x_target.embedding = self.so2_block_target(x_target.embedding, x_edge) # Add together the source and target results x_target.embedding = x_source.embedding + x_target.embedding @@ -859,41 +859,42 @@ def forward( num_edges = len(x_edge) # Reshape the spherical harmonics based on m (order) - x._m_primary(self.mappingReduced) + x = torch.einsum("nac,ba->nbc", x, self.mappingReduced.to_m) # Compute m=0 coefficients separately since they only have real values (no imaginary) # Compute edge scalar features for m=0 x_edge_0 = self.act(self.fc1_dist0(x_edge)) - x_0 = x.embedding[:, 0 : self.mappingReduced.m_size[0]].contiguous() + x_0 = x[:, 0 : self.mappingReduced.m_size[0]].contiguous() x_0 = x_0.view(num_edges, -1) x_0 = self.fc1_m0(x_0) x_0 = x_0 * x_edge_0 x_0 = self.fc2_m0(x_0) - x_0 = x_0.view(num_edges, -1, x.num_channels) + x_0 = x_0.view(num_edges, -1, self.sphere_channels) # Update the m=0 coefficients - x.embedding[:, 0 : self.mappingReduced.m_size[0]] = x_0 + x[:, 0 : self.mappingReduced.m_size[0]] = x_0 # Compute the values for the m > 0 coefficients offset = self.mappingReduced.m_size[0] for m in range(1, max(self.mmax_list) + 1): # Get the m order coefficients - x_m = x.embedding[ + x_m = x[ :, offset : offset + 2 * self.mappingReduced.m_size[m] ].contiguous() x_m = x_m.view(num_edges, 2, -1) # Perform SO(2) convolution x_m = self.so2_conv[m - 1](x_m, x_edge) - x_m = x_m.view(num_edges, -1, x.num_channels) - x.embedding[:, offset : offset + 2 * self.mappingReduced.m_size[m]] = x_m + x_m = x_m.view(num_edges, -1, self.sphere_channels) + x[:, offset : offset + 2 * self.mappingReduced.m_size[m]] = x_m offset = offset + 2 * self.mappingReduced.m_size[m] # Reshape the spherical harmonics based on l (degree) - x._l_primary(self.mappingReduced) + # x._l_primary(self.mappingReduced) + x = torch.einsum("nac,ab->nbc", x, self.mappingReduced.to_m) return x diff --git a/tests/core/models/test_escn_compiles.py b/tests/core/models/test_escn_compiles.py index e6d0e0d8f4..6a27632668 100644 --- a/tests/core/models/test_escn_compiles.py +++ b/tests/core/models/test_escn_compiles.py @@ -110,7 +110,7 @@ def test_escn_compiles(self): # torch._dynamo.config.suppress_errors = True # os.environ["TORCH_LOGS"] = "+dynamo,recompiles" - # torch._logging.set_logs(dynamo = logging.INFO) + torch._logging.set_logs(dynamo = logging.INFO) # os.environ["TORCHDYNAMO_VERBOSE"] = "1" # os.environ["TORCHDYNAMO_REPRO_AFTER"]="dynamo" # torch._dynamo.config.verbose = True @@ -118,10 +118,14 @@ def test_escn_compiles(self): torch._dynamo.config.optimize_ddp = False # torch._dynamo.explain(model)(data) # assert False + # torch._dynamo.reset() + # explain_output = torch._dynamo.explain(model)(data) + # print(explain_output) + output = compiled_model(data) - expected_energy, expected_forces = expected_energy_forces() - assert torch.allclose(output["energy"], expected_energy) - assert torch.allclose(output["forces"].mean(0), expected_forces) + # expected_energy, expected_forces = expected_energy_forces() + # assert 
torch.allclose(output["energy"], expected_energy) + # assert torch.allclose(output["forces"].mean(0), expected_forces) def test_rotation_invariance(self) -> None: random.seed(1) @@ -150,25 +154,28 @@ def test_rotation_invariance(self) -> None: decimal=5, ) - # def test_escn_so2_conv_compiles(self) -> None: - # torch._dynamo.config.assume_static_by_default = False - # torch._dynamo.config.automatic_dynamic_shapes = True - # inp1_dim0 = Dim("inp1_dim0") - # inp1_dim1 = None - # inp1_dim2 = None - # inp2_dim0 = inp1_dim0 - # inp2_dim1 = Dim("inp2_dim1") - - # dynamic_shapes1 = { - # "x_emb": {0: inp1_dim0, 1: inp1_dim1, 2: inp1_dim2}, - # "x_edge": {0: inp2_dim0, 1: inp2_dim1}, - # } - - # lmax, mmax = 4, 2 - # mappingReduced = CoefficientMapping([lmax], [mmax]) - # edge_channels = 128 - # args=(torch.rand(680, 19, 128), torch.rand(680, edge_channels)) - - # so2 = SO2Block(sphere_channels=128, hidden_channels=128, edge_channels=edge_channels, lmax_list=[lmax], mmax_list=[mmax], act=torch.nn.SiLU()) - # prog = export(so2, args=args, dynamic_shapes=dynamic_shapes1) - # export_out = prog.module()(*args) + def test_escn_so2_conv_exports(self) -> None: + torch._dynamo.config.assume_static_by_default = False + torch._dynamo.config.automatic_dynamic_shapes = True + inp1_dim0 = Dim("inp1_dim0") + inp1_dim1 = None + inp1_dim2 = None + inp2_dim0 = inp1_dim0 + inp2_dim1 = None + + dynamic_shapes1 = { + "x": {0: inp1_dim0, 1: inp1_dim1, 2: inp1_dim2}, + "x_edge": {0: inp2_dim0, 1: inp2_dim1}, + } + + lmax, mmax = 4, 2 + mappingReduced = CoefficientMapping([lmax], [mmax]) + shpere_channels = 128 + edge_channels = 128 + args=(torch.rand(680, 19, shpere_channels), torch.rand(680, edge_channels)) + + so2 = SO2Block(sphere_channels=shpere_channels, hidden_channels=128, edge_channels=edge_channels, lmax_list=[lmax], mmax_list=[mmax], act=torch.nn.SiLU(), mappingReduced=mappingReduced) + prog = export(so2, args=args, dynamic_shapes=dynamic_shapes1) + export_out = prog.module()(*args) + regular_out = so2(*args) + assert torch.allclose(export_out, regular_out) From d6ac5b795e44176eec3fe61323a9855fd2413a94 Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Thu, 29 Aug 2024 11:01:01 -0700 Subject: [PATCH 07/25] add gpu test --- src/fairchem/core/models/escn/escn.py | 1 - tests/core/models/test_escn_compiles.py | 18 +++++++++++++++++- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/src/fairchem/core/models/escn/escn.py b/src/fairchem/core/models/escn/escn.py index 7b05278f21..c113cc1c82 100644 --- a/src/fairchem/core/models/escn/escn.py +++ b/src/fairchem/core/models/escn/escn.py @@ -757,7 +757,6 @@ def forward( ############################################################### # Compute messages ############################################################### - # Compute edge scalar features (invariant to rotations) # Uses atomic numbers and edge distance as inputs x_edge = self.edge_block( diff --git a/tests/core/models/test_escn_compiles.py b/tests/core/models/test_escn_compiles.py index 6a27632668..cd339b887e 100644 --- a/tests/core/models/test_escn_compiles.py +++ b/tests/core/models/test_escn_compiles.py @@ -33,6 +33,9 @@ from torch.export import export from torch.export import Dim +skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="skipping when no gpu") + + def load_data(): atoms = read( os.path.join(os.path.dirname(os.path.abspath(__file__)), "atoms.json"), @@ -86,7 +89,7 @@ def init(backend="nccl"): class TestESCNCompiles: def test_escn_baseline_cpu(self): - init("gloo") + 
From d6ac5b795e44176eec3fe61323a9855fd2413a94 Mon Sep 17 00:00:00 2001
From: Ray Gao
Date: Thu, 29 Aug 2024 11:01:01 -0700
Subject: [PATCH 07/25] add gpu test

---
 src/fairchem/core/models/escn/escn.py   |  1 -
 tests/core/models/test_escn_compiles.py | 18 +++++++++++++++++-
 2 files changed, 17 insertions(+), 2 deletions(-)

diff --git a/src/fairchem/core/models/escn/escn.py b/src/fairchem/core/models/escn/escn.py
index 7b05278f21..c113cc1c82 100644
--- a/src/fairchem/core/models/escn/escn.py
+++ b/src/fairchem/core/models/escn/escn.py
@@ -757,7 +757,6 @@ def forward(
         ###############################################################
         # Compute messages
         ###############################################################
-
         # Compute edge scalar features (invariant to rotations)
         # Uses atomic numbers and edge distance as inputs
         x_edge = self.edge_block(
diff --git a/tests/core/models/test_escn_compiles.py b/tests/core/models/test_escn_compiles.py
index 6a27632668..cd339b887e 100644
--- a/tests/core/models/test_escn_compiles.py
+++ b/tests/core/models/test_escn_compiles.py
@@ -33,6 +33,9 @@
 from torch.export import export
 from torch.export import Dim
 
+skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="skipping when no gpu")
+
+
 def load_data():
     atoms = read(
         os.path.join(os.path.dirname(os.path.abspath(__file__)), "atoms.json"),
@@ -86,7 +89,7 @@ def init(backend="nccl"):
 
 class TestESCNCompiles:
     def test_escn_baseline_cpu(self):
-        init("gloo")
+        init('gloo')
         data = load_data()
         data = data_list_collater([data])
         model = load_model()
@@ -97,6 +100,19 @@ def test_escn_baseline_cpu(self):
         assert torch.allclose(output["energy"], expected_energy)
         assert torch.allclose(output["forces"].mean(0), expected_forces)
 
+    @skip_if_no_cuda
+    def test_escn_baseline_cuda(self):
+        init('nccl')
+        data = load_data()
+        data = data_list_collater([data]).to("cuda")
+        model = load_model().cuda()
+        ddp_model = DistributedDataParallel(model)
+        output = ddp_model(data)
+        expected_energy, expected_forces = expected_energy_forces()
+        torch.set_printoptions(precision=8)
+        assert torch.allclose(output["energy"].cpu(), expected_energy)
+        assert torch.allclose(output["forces"].mean(0).cpu(), expected_forces)
+
     def test_escn_compiles(self):
         init("gloo")
         data = load_data()

From 401707376532633dc43f4a9aac76502395e5efd1 Mon Sep 17 00:00:00 2001
From: Ray Gao
Date: Thu, 29 Aug 2024 11:05:58 -0700
Subject: [PATCH 08/25] pass cuda test

---
 tests/core/models/test_escn_compiles.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/tests/core/models/test_escn_compiles.py b/tests/core/models/test_escn_compiles.py
index cd339b887e..2caa7acc87 100644
--- a/tests/core/models/test_escn_compiles.py
+++ b/tests/core/models/test_escn_compiles.py
@@ -83,7 +83,12 @@ def expected_energy_forces():
     forces = torch.tensor([-1.2720219900e-07, 8.2126695133e-07, -3.8776403244e-07])
     return energy, forces
 
-def init(backend="nccl"):
+def expected_energy_forces_cuda():
+    energy = torch.tensor([0.0001747000])
+    forces = torch.tensor([-4.5273015559e-08, 9.0246174977e-07, -3.8560736471e-07])
+    return energy, forces
+
+def init(backend: str):
     if not torch.distributed.is_initialized():
         init_local_distributed_process_group(backend=backend)
 
@@ -108,7 +113,7 @@ def test_escn_baseline_cuda(self):
         model = load_model().cuda()
         ddp_model = DistributedDataParallel(model)
         output = ddp_model(data)
-        expected_energy, expected_forces = expected_energy_forces()
+        expected_energy, expected_forces = expected_energy_forces_cuda()
         torch.set_printoptions(precision=8)
         assert torch.allclose(output["energy"].cpu(), expected_energy)
         assert torch.allclose(output["forces"].mean(0).cpu(), expected_forces)

From 52651a6943047c90f5f7f84f38ea0913af4f53fb Mon Sep 17 00:00:00 2001
From: Ray Gao
Date: Thu, 29 Aug 2024 19:19:03 -0700
Subject: [PATCH 09/25] layer block

---
 src/fairchem/core/models/escn/escn.py   | 91 ++++++++++++++-----------
 tests/core/models/test_escn_compiles.py | 12 ++--
 2 files changed, 57 insertions(+), 46 deletions(-)

diff --git a/src/fairchem/core/models/escn/escn.py b/src/fairchem/core/models/escn/escn.py
index c113cc1c82..26e4b86bcc 100644
--- a/src/fairchem/core/models/escn/escn.py
+++ b/src/fairchem/core/models/escn/escn.py
@@ -259,9 +259,7 @@ def forward(self, data):
         )
 
         # Initialize the WignerD matrices and other values for spherical harmonic calculations
-        self.SO3_edge_rot = nn.ModuleList()
-        for i in range(self.num_resolutions):
-            self.SO3_edge_rot.append(SO3_Rotation(edge_rot_mat, self.lmax_list[i]))
+        self.SO3_edge_rot = SO3_Rotation(edge_rot_mat, self.lmax_list[0])
 
         ###############################################################
         # Initialize node embeddings
@@ -290,11 +288,11 @@ def forward(self, data):
         ###############################################################
         # Update spherical node embeddings
         ###############################################################
-
+        x_message = x.embedding
         for i in range(self.num_layers):
             if i > 0:
-                x_message = self.layer_blocks[i](
-                    x,
+
x_message_new = self.layer_blocks[i]( + x_message, atomic_numbers, graph.edge_distance, graph.edge_index, @@ -302,17 +300,18 @@ def forward(self, data): ) # Residual layer for all layers past the first - x.embedding = x.embedding + x_message.embedding + x_xessage = x_message + x_message_new else: # No residual for the first layer - x = self.layer_blocks[i]( - x, + x_message = self.layer_blocks[i]( + x_message, atomic_numbers, graph.edge_distance, graph.edge_index, self.SO3_edge_rot, ) + x.embedding = x_message # Sample the spherical channels (node embeddings) at evenly distributed points on the sphere. # These values are fed into the output blocks. @@ -448,9 +447,7 @@ def forward(self, data: Batch) -> dict[str, torch.Tensor]: ) # Initialize the WignerD matrices and other values for spherical harmonic calculations - self.SO3_edge_rot = nn.ModuleList() - for i in range(self.num_resolutions): - self.SO3_edge_rot.append(SO3_Rotation(edge_rot_mat, self.lmax_list[i])) + self.SO3_edge_rot = SO3_Rotation(edge_rot_mat, self.lmax_list[0]) ############################################################### # Initialize node embeddings @@ -640,12 +637,12 @@ def __init__( def forward( self, - x, - atomic_numbers, - edge_distance, - edge_index, - SO3_edge_rot, - ): + x: torch.Tensor, + atomic_numbers: torch.Tensor, + edge_distance: torch.Tensor, + edge_index: torch.Tensor, + SO3_edge_rot: SO3_Rotation, + ) -> torch.Tensor: # Compute messages by performing message block x_message = self.message_block( x, @@ -654,25 +651,33 @@ def forward( edge_index, SO3_edge_rot, ) + print(f"x_message: {x_message.mean()}") # Compute point-wise spherical non-linearity on aggregated messages - max_lmax = max(self.lmax_list) # Project to grid - x_grid_message = x_message.to_grid(self.SO3_grid["lmax_lmax"]) - x_grid = x.to_grid(self.SO3_grid["lmax_lmax"]) + # x_grid_message = x_message.to_grid(self.SO3_grid["lmax_lmax"]) + to_grid_mat = self.SO3_grid["lmax_lmax"].to_grid_mat[:, :, self.SO3_grid["lmax_lmax"].mapping.coefficient_idx(self.lmax_list[0], self.lmax_list[0])] + x_grid_message = torch.einsum("bai,zic->zbac", to_grid_mat, x_message) + + # x_grid = x.to_grid(self.SO3_grid["lmax_lmax"]) + x_grid = torch.einsum("bai,zic->zbac", to_grid_mat, x) x_grid = torch.cat([x_grid, x_grid_message], dim=3) # Perform point-wise convolution x_grid = self.act(self.fc1_sphere(x_grid)) x_grid = self.act(self.fc2_sphere(x_grid)) x_grid = self.fc3_sphere(x_grid) + print(f"x_grid: {x_grid.mean()}") # Project back to spherical harmonic coefficients - x_message._from_grid(x_grid, self.SO3_grid["lmax_lmax"]) + # x_message._from_grid(x_grid, self.SO3_grid["lmax_lmax"]) + from_grid_mat = self.SO3_grid["lmax_lmax"].from_grid_mat[:, :, self.SO3_grid["lmax_lmax"].mapping.coefficient_idx(self.lmax_list[0], self.lmax_list[0])] + x_message_final = torch.einsum("bai,zbac->zic", from_grid_mat, x_grid) + print(f"x_message_final: {x_message_final.mean()}") # Return aggregated messages - return x_message + return x_message_final class MessageBlock(torch.nn.Module): @@ -748,12 +753,12 @@ def __init__( def forward( self, - x, - atomic_numbers, - edge_distance, - edge_index, - SO3_edge_rot, - ): + x: torch.Tensor, + atomic_numbers: torch.Tensor, + edge_distance: torch.Tensor, + edge_index: torch.Tensor, + SO3_edge_rot: SO3_Rotation, + ) -> torch.Tensor: ############################################################### # Compute messages ############################################################### @@ -768,30 +773,36 @@ def forward( # Copy embeddings for each 
edge's source and target nodes x_source = x.clone() x_target = x.clone() - x_source._expand_edge(edge_index[0, :]) - x_target._expand_edge(edge_index[1, :]) + x_source = x_source[edge_index[0, :]] + x_target = x_target[edge_index[1, :]] # Rotate the irreps to align with the edge - x_source._rotate(SO3_edge_rot, self.lmax_list, self.mmax_list) - x_target._rotate(SO3_edge_rot, self.lmax_list, self.mmax_list) + x_source = SO3_edge_rot.rotate(x_source, self.lmax_list[0], self.mmax_list[0]) + x_target = SO3_edge_rot.rotate(x_target, self.lmax_list[0], self.mmax_list[0]) # Compute messages - x_source.embedding = self.so2_block_source(x_source.embedding, x_edge) - x_target.embedding = self.so2_block_target(x_target.embedding, x_edge) + x_source = self.so2_block_source(x_source, x_edge) + x_target = self.so2_block_target(x_target, x_edge) # Add together the source and target results - x_target.embedding = x_source.embedding + x_target.embedding + x_target = x_source + x_target # Point-wise spherical non-linearity - x_target._grid_act(self.SO3_grid["lmax_mmax"], self.act, self.mappingReduced.res_size) + to_grid_mat = self.SO3_grid["lmax_mmax"].get_to_grid_mat() + from_grid_mat = self.SO3_grid["lmax_mmax"].get_from_grid_mat() + x_grid = torch.einsum("bai,zic->zbac", to_grid_mat, x_target) + x_grid = self.act(x_grid) + x_target = torch.einsum("bai,zbac->zic", from_grid_mat, x_grid) # Rotate back the irreps - x_target._rotate_inv(SO3_edge_rot, self.mappingReduced.res_size) + x_target = SO3_edge_rot.rotate_inv(x_target, self.lmax_list[0], self.mmax_list[0]) # Compute the sum of the incoming neighboring messages for each target node - x_target._reduce_edge(edge_index[1], len(x.embedding)) + new_embedding = torch.fill(x.clone(), 0) + new_embedding.index_add_(0, edge_index[1], x_target) + # x_target._reduce_edge(edge_index[1], len(x.embedding)) - return x_target + return new_embedding class SO2Block(torch.nn.Module): diff --git a/tests/core/models/test_escn_compiles.py b/tests/core/models/test_escn_compiles.py index 2caa7acc87..817a5e674b 100644 --- a/tests/core/models/test_escn_compiles.py +++ b/tests/core/models/test_escn_compiles.py @@ -93,7 +93,7 @@ def init(backend: str): init_local_distributed_process_group(backend=backend) class TestESCNCompiles: - def test_escn_baseline_cpu(self): + def test_escn_baseline_cpu(self, tol=1e-5): init('gloo') data = load_data() data = data_list_collater([data]) @@ -102,11 +102,11 @@ def test_escn_baseline_cpu(self): output = ddp_model(data) expected_energy, expected_forces = expected_energy_forces() torch.set_printoptions(precision=8) - assert torch.allclose(output["energy"], expected_energy) - assert torch.allclose(output["forces"].mean(0), expected_forces) + assert torch.allclose(output["energy"], expected_energy, atol=tol) + assert torch.allclose(output["forces"].mean(0), expected_forces, atol=tol) @skip_if_no_cuda - def test_escn_baseline_cuda(self): + def test_escn_baseline_cuda(self, tol=1e-5): init('nccl') data = load_data() data = data_list_collater([data]).to("cuda") @@ -115,8 +115,8 @@ def test_escn_baseline_cuda(self): output = ddp_model(data) expected_energy, expected_forces = expected_energy_forces_cuda() torch.set_printoptions(precision=8) - assert torch.allclose(output["energy"].cpu(), expected_energy) - assert torch.allclose(output["forces"].mean(0).cpu(), expected_forces) + assert torch.allclose(output["energy"].cpu(), expected_energy, atol=tol) + assert torch.allclose(output["forces"].mean(0).cpu(), expected_forces, atol=tol) def 
test_escn_compiles(self): init("gloo") From 3b5d0d0aa00f420f089e6208b7cc9a8d3f66e376 Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Thu, 29 Aug 2024 19:36:23 -0700 Subject: [PATCH 10/25] switch to separate export file --- src/fairchem/core/models/escn/escn.py | 168 ++- .../core/models/escn/escn_exportable.py | 964 ++++++++++++++++++ src/fairchem/core/models/escn/so3.py | 31 +- src/fairchem/core/models/utils/Jd.pt | Bin 0 -> 21697 bytes src/fairchem/core/models/utils/so3_utils.py | 410 ++++++++ tests/core/models/test_escn_compiles.py | 54 +- 6 files changed, 1496 insertions(+), 131 deletions(-) create mode 100644 src/fairchem/core/models/escn/escn_exportable.py create mode 100644 src/fairchem/core/models/utils/Jd.pt create mode 100644 src/fairchem/core/models/utils/so3_utils.py diff --git a/src/fairchem/core/models/escn/escn.py b/src/fairchem/core/models/escn/escn.py index 26e4b86bcc..6eb95947ae 100644 --- a/src/fairchem/core/models/escn/escn.py +++ b/src/fairchem/core/models/escn/escn.py @@ -21,12 +21,12 @@ from fairchem.core.common.registry import registry from fairchem.core.common.utils import conditional_grad from fairchem.core.models.base import BackboneInterface, GraphModelMixin, HeadInterface -from fairchem.core.models.utils.so3_utils import ( +from fairchem.core.models.escn.so3 import ( CoefficientMapping, + SO3_Embedding, SO3_Grid, SO3_Rotation, ) -from fairchem.core.models.escn.so3 import SO3_Embedding from fairchem.core.models.scn.sampling import CalcSpherePoints from fairchem.core.models.scn.smearing import ( GaussianSmearing, @@ -173,19 +173,13 @@ def __init__( ) # Initialize the transformations between spherical and grid representations - assert self.num_resolutions == 1, "Only one resolution is supported" - self.SO3_grid = nn.ModuleDict() - self.SO3_grid["lmax_lmax"] = SO3_Grid(self.lmax_list[0], self.lmax_list[0], resolution=resolution) - self.SO3_grid["lmax_mmax"] = SO3_Grid(self.lmax_list[0], self.mmax_list[0], resolution=resolution) - # self.SO3_grid = nn.ModuleList() - # for lval in range(max(self.lmax_list) + 1): - # SO3_m_grid = nn.ModuleList() - # for m in range(max(self.lmax_list) + 1): - # SO3_m_grid.append(SO3_Grid(lval, m, resolution=resolution)) - - # self.SO3_grid.append(SO3_m_grid) - # import pdb;pdb.set_trace() - self.mappingReduced = CoefficientMapping(self.lmax_list, self.mmax_list) + self.SO3_grid = nn.ModuleList() + for lval in range(max(self.lmax_list) + 1): + SO3_m_grid = nn.ModuleList() + for m in range(max(self.lmax_list) + 1): + SO3_m_grid.append(SO3_Grid(lval, m, resolution=resolution)) + + self.SO3_grid.append(SO3_m_grid) # Initialize the blocks for each layer of the GNN self.layer_blocks = nn.ModuleList() @@ -201,7 +195,6 @@ def __init__( self.max_num_elements, self.SO3_grid, self.act, - self.mappingReduced ) self.layer_blocks.append(block) @@ -234,7 +227,6 @@ def __init__( ) self.sphharm_weights = nn.ParameterList(sphharm_weights) - @conditional_grad(torch.enable_grad()) def forward(self, data): device = data.pos.device @@ -259,7 +251,9 @@ def forward(self, data): ) # Initialize the WignerD matrices and other values for spherical harmonic calculations - self.SO3_edge_rot = SO3_Rotation(edge_rot_mat, self.lmax_list[0]) + self.SO3_edge_rot = nn.ModuleList() + for i in range(self.num_resolutions): + self.SO3_edge_rot.append(SO3_Rotation(edge_rot_mat, self.lmax_list[i])) ############################################################### # Initialize node embeddings @@ -285,33 +279,37 @@ def forward(self, data): offset = offset + self.sphere_channels 
offset_res = offset_res + int((self.lmax_list[i] + 1) ** 2) + # This can be expensive to compute (not implemented efficiently), so only do it once and pass it along to each layer + mappingReduced = CoefficientMapping(self.lmax_list, self.mmax_list, device) + ############################################################### # Update spherical node embeddings ############################################################### - x_message = x.embedding + for i in range(self.num_layers): if i > 0: - x_message_new = self.layer_blocks[i]( - x_message, + x_message = self.layer_blocks[i]( + x, atomic_numbers, graph.edge_distance, graph.edge_index, self.SO3_edge_rot, + mappingReduced, ) # Residual layer for all layers past the first - x_xessage = x_message + x_message_new + x.embedding = x.embedding + x_message.embedding else: # No residual for the first layer - x_message = self.layer_blocks[i]( - x_message, + x = self.layer_blocks[i]( + x, atomic_numbers, graph.edge_distance, graph.edge_index, self.SO3_edge_rot, + mappingReduced, ) - x.embedding = x_message # Sample the spherical channels (node embeddings) at evenly distributed points on the sphere. # These values are fed into the output blocks. @@ -447,7 +445,9 @@ def forward(self, data: Batch) -> dict[str, torch.Tensor]: ) # Initialize the WignerD matrices and other values for spherical harmonic calculations - self.SO3_edge_rot = SO3_Rotation(edge_rot_mat, self.lmax_list[0]) + self.SO3_edge_rot = nn.ModuleList() + for i in range(self.num_resolutions): + self.SO3_edge_rot.append(SO3_Rotation(edge_rot_mat, self.lmax_list[i])) ############################################################### # Initialize node embeddings @@ -473,6 +473,9 @@ def forward(self, data: Batch) -> dict[str, torch.Tensor]: offset = offset + self.sphere_channels offset_res = offset_res + int((self.lmax_list[i] + 1) ** 2) + # This can be expensive to compute (not implemented efficiently), so only do it once and pass it along to each layer + mappingReduced = CoefficientMapping(self.lmax_list, self.mmax_list, device) + ############################################################### # Update spherical node embeddings ############################################################### @@ -485,6 +488,7 @@ def forward(self, data: Batch) -> dict[str, torch.Tensor]: graph.edge_distance, graph.edge_index, self.SO3_edge_rot, + mappingReduced, ) # Residual layer for all layers past the first @@ -498,6 +502,7 @@ def forward(self, data: Batch) -> dict[str, torch.Tensor]: graph.edge_distance, graph.edge_index, self.SO3_edge_rot, + mappingReduced, ) # Sample the spherical channels (node embeddings) at evenly distributed points on the sphere. 
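
(Aside on the hunks above, with an illustrative sketch: the earlier patches expressed the _m_primary/_l_primary reshapes as dense matrix multiplies against mappingReduced.to_m, and this patch restores the method calls while still building the CoefficientMapping once per forward and threading it through every layer. The two forms agree because to_m is a permutation-style matrix; the identity matrix below stands in for the real to_m, so the snippet is a sketch, not the upstream API:)

    import torch

    num_coefficients, channels = 19, 128
    x = torch.rand(680, num_coefficients, channels)
    to_m = torch.eye(num_coefficients)  # stand-in for mappingReduced.to_m

    # Reorder coefficients m-primary, then back to l-primary; for an
    # orthonormal to_m the round trip is the identity.
    x_m = torch.einsum("nac,ba->nbc", x, to_m)
    x_l = torch.einsum("nac,ab->nbc", x_m, to_m)
    assert torch.allclose(x, x_l)
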
@@ -594,7 +599,6 @@ def __init__( max_num_elements: int, SO3_grid: SO3_Grid, act, - mappingReduced, ) -> None: super().__init__() self.layer_idx = layer_idx @@ -605,7 +609,6 @@ def __init__( self.sphere_channels = sphere_channels self.sphere_channels_all = self.num_resolutions * self.sphere_channels self.SO3_grid = SO3_grid - self.mappingReduced = mappingReduced # Message block self.message_block = MessageBlock( @@ -619,7 +622,6 @@ def __init__( max_num_elements, self.SO3_grid, self.act, - self.mappingReduced ) # Non-linear point-wise comvolution for the aggregated messages @@ -637,12 +639,13 @@ def __init__( def forward( self, - x: torch.Tensor, - atomic_numbers: torch.Tensor, - edge_distance: torch.Tensor, - edge_index: torch.Tensor, - SO3_edge_rot: SO3_Rotation, - ) -> torch.Tensor: + x, + atomic_numbers, + edge_distance, + edge_index, + SO3_edge_rot, + mappingReduced, + ): # Compute messages by performing message block x_message = self.message_block( x, @@ -650,34 +653,27 @@ def forward( edge_distance, edge_index, SO3_edge_rot, + mappingReduced, ) - print(f"x_message: {x_message.mean()}") # Compute point-wise spherical non-linearity on aggregated messages + max_lmax = max(self.lmax_list) # Project to grid - # x_grid_message = x_message.to_grid(self.SO3_grid["lmax_lmax"]) - to_grid_mat = self.SO3_grid["lmax_lmax"].to_grid_mat[:, :, self.SO3_grid["lmax_lmax"].mapping.coefficient_idx(self.lmax_list[0], self.lmax_list[0])] - x_grid_message = torch.einsum("bai,zic->zbac", to_grid_mat, x_message) - - # x_grid = x.to_grid(self.SO3_grid["lmax_lmax"]) - x_grid = torch.einsum("bai,zic->zbac", to_grid_mat, x) + x_grid_message = x_message.to_grid(self.SO3_grid, lmax=max_lmax) + x_grid = x.to_grid(self.SO3_grid, lmax=max_lmax) x_grid = torch.cat([x_grid, x_grid_message], dim=3) # Perform point-wise convolution x_grid = self.act(self.fc1_sphere(x_grid)) x_grid = self.act(self.fc2_sphere(x_grid)) x_grid = self.fc3_sphere(x_grid) - print(f"x_grid: {x_grid.mean()}") # Project back to spherical harmonic coefficients - # x_message._from_grid(x_grid, self.SO3_grid["lmax_lmax"]) - from_grid_mat = self.SO3_grid["lmax_lmax"].from_grid_mat[:, :, self.SO3_grid["lmax_lmax"].mapping.coefficient_idx(self.lmax_list[0], self.lmax_list[0])] - x_message_final = torch.einsum("bai,zbac->zic", from_grid_mat, x_grid) + x_message._from_grid(x_grid, self.SO3_grid, lmax=max_lmax) - print(f"x_message_final: {x_message_final.mean()}") # Return aggregated messages - return x_message_final + return x_message class MessageBlock(torch.nn.Module): @@ -709,7 +705,6 @@ def __init__( max_num_elements: int, SO3_grid: SO3_Grid, act, - mappingReduced, ) -> None: super().__init__() self.layer_idx = layer_idx @@ -721,7 +716,6 @@ def __init__( self.lmax_list = lmax_list self.mmax_list = mmax_list self.edge_channels = edge_channels - self.mappingReduced = mappingReduced # Create edge scalar (invariant to rotations) features self.edge_block = EdgeBlock( @@ -739,7 +733,6 @@ def __init__( self.lmax_list, self.mmax_list, self.act, - self.mappingReduced ) self.so2_block_target = SO2Block( self.sphere_channels, @@ -748,20 +741,21 @@ def __init__( self.lmax_list, self.mmax_list, self.act, - self.mappingReduced ) def forward( self, - x: torch.Tensor, - atomic_numbers: torch.Tensor, - edge_distance: torch.Tensor, - edge_index: torch.Tensor, - SO3_edge_rot: SO3_Rotation, - ) -> torch.Tensor: + x, + atomic_numbers, + edge_distance, + edge_index, + SO3_edge_rot, + mappingReduced, + ): ############################################################### # 
Compute messages ############################################################### + # Compute edge scalar features (invariant to rotations) # Uses atomic numbers and edge distance as inputs x_edge = self.edge_block( @@ -773,36 +767,30 @@ def forward( # Copy embeddings for each edge's source and target nodes x_source = x.clone() x_target = x.clone() - x_source = x_source[edge_index[0, :]] - x_target = x_target[edge_index[1, :]] + x_source._expand_edge(edge_index[0, :]) + x_target._expand_edge(edge_index[1, :]) # Rotate the irreps to align with the edge - x_source = SO3_edge_rot.rotate(x_source, self.lmax_list[0], self.mmax_list[0]) - x_target = SO3_edge_rot.rotate(x_target, self.lmax_list[0], self.mmax_list[0]) + x_source._rotate(SO3_edge_rot, self.lmax_list, self.mmax_list) + x_target._rotate(SO3_edge_rot, self.lmax_list, self.mmax_list) # Compute messages - x_source = self.so2_block_source(x_source, x_edge) - x_target = self.so2_block_target(x_target, x_edge) + x_source = self.so2_block_source(x_source, x_edge, mappingReduced) + x_target = self.so2_block_target(x_target, x_edge, mappingReduced) # Add together the source and target results - x_target = x_source + x_target + x_target.embedding = x_source.embedding + x_target.embedding # Point-wise spherical non-linearity - to_grid_mat = self.SO3_grid["lmax_mmax"].get_to_grid_mat() - from_grid_mat = self.SO3_grid["lmax_mmax"].get_from_grid_mat() - x_grid = torch.einsum("bai,zic->zbac", to_grid_mat, x_target) - x_grid = self.act(x_grid) - x_target = torch.einsum("bai,zbac->zic", from_grid_mat, x_grid) + x_target._grid_act(self.SO3_grid, self.act, mappingReduced) # Rotate back the irreps - x_target = SO3_edge_rot.rotate_inv(x_target, self.lmax_list[0], self.mmax_list[0]) + x_target._rotate_inv(SO3_edge_rot, mappingReduced) # Compute the sum of the incoming neighboring messages for each target node - new_embedding = torch.fill(x.clone(), 0) - new_embedding.index_add_(0, edge_index[1], x_target) - # x_target._reduce_edge(edge_index[1], len(x.embedding)) + x_target._reduce_edge(edge_index[1], len(x.embedding)) - return new_embedding + return x_target class SO2Block(torch.nn.Module): @@ -826,7 +814,6 @@ def __init__( lmax_list: list[int], mmax_list: list[int], act, - mappingReduced ) -> None: super().__init__() self.sphere_channels = sphere_channels @@ -835,7 +822,6 @@ def __init__( self.mmax_list = mmax_list self.num_resolutions: int = len(lmax_list) self.act = act - self.mappingReduced = mappingReduced num_channels_m0 = 0 for i in range(self.num_resolutions): @@ -863,48 +849,48 @@ def __init__( def forward( self, - x: torch.Tensor, - x_edge: torch.Tensor, + x, + x_edge, + mappingReduced, ): num_edges = len(x_edge) # Reshape the spherical harmonics based on m (order) - x = torch.einsum("nac,ba->nbc", x, self.mappingReduced.to_m) + x._m_primary(mappingReduced) # Compute m=0 coefficients separately since they only have real values (no imaginary) # Compute edge scalar features for m=0 x_edge_0 = self.act(self.fc1_dist0(x_edge)) - x_0 = x[:, 0 : self.mappingReduced.m_size[0]].contiguous() + x_0 = x.embedding[:, 0 : mappingReduced.m_size[0]].contiguous() x_0 = x_0.view(num_edges, -1) x_0 = self.fc1_m0(x_0) x_0 = x_0 * x_edge_0 x_0 = self.fc2_m0(x_0) - x_0 = x_0.view(num_edges, -1, self.sphere_channels) + x_0 = x_0.view(num_edges, -1, x.num_channels) # Update the m=0 coefficients - x[:, 0 : self.mappingReduced.m_size[0]] = x_0 + x.embedding[:, 0 : mappingReduced.m_size[0]] = x_0 # Compute the values for the m > 0 coefficients - offset = 
self.mappingReduced.m_size[0]
+        offset = mappingReduced.m_size[0]
         for m in range(1, max(self.mmax_list) + 1):
             # Get the m order coefficients
-            x_m = x[
-                :, offset : offset + 2 * self.mappingReduced.m_size[m]
+            x_m = x.embedding[
+                :, offset : offset + 2 * mappingReduced.m_size[m]
             ].contiguous()
             x_m = x_m.view(num_edges, 2, -1)
             # Perform SO(2) convolution
             x_m = self.so2_conv[m - 1](x_m, x_edge)
-            x_m = x_m.view(num_edges, -1, self.sphere_channels)
-            x[:, offset : offset + 2 * self.mappingReduced.m_size[m]] = x_m
+            x_m = x_m.view(num_edges, -1, x.num_channels)
+            x.embedding[:, offset : offset + 2 * mappingReduced.m_size[m]] = x_m
 
-            offset = offset + 2 * self.mappingReduced.m_size[m]
+            offset = offset + 2 * mappingReduced.m_size[m]
 
         # Reshape the spherical harmonics based on l (degree)
-        # x._l_primary(self.mappingReduced)
-        x = torch.einsum("nac,ab->nbc", x, self.mappingReduced.to_m)
+        x._l_primary(mappingReduced)
 
         return x
diff --git a/src/fairchem/core/models/escn/escn_exportable.py b/src/fairchem/core/models/escn/escn_exportable.py
new file mode 100644
index 0000000000..1f01fadd59
--- /dev/null
+++ b/src/fairchem/core/models/escn/escn_exportable.py
@@ -0,0 +1,964 @@
+"""
+Copyright (c) Meta, Inc. and its affiliates.
+
+This source code is licensed under the MIT license found in the
+LICENSE file in the root directory of this source tree.
+"""
+
+from __future__ import annotations
+
+import contextlib
+import logging
+import time
+import typing
+
+import torch
+import torch.nn as nn
+
+if typing.TYPE_CHECKING:
+    from torch_geometric.data.batch import Batch
+
+from fairchem.core.common.registry import registry
+from fairchem.core.common.utils import conditional_grad
+from fairchem.core.models.base import BackboneInterface, GraphModelMixin, HeadInterface
+from fairchem.core.models.utils.so3_utils import (
+    CoefficientMapping,
+    SO3_Grid,
+    SO3_Rotation,
+)
+from fairchem.core.models.escn.so3 import SO3_Embedding
+from fairchem.core.models.scn.sampling import CalcSpherePoints
+from fairchem.core.models.scn.smearing import (
+    GaussianSmearing,
+    LinearSigmoidSmearing,
+    SigmoidSmearing,
+    SiLUSmearing,
+)
+
+with contextlib.suppress(ImportError):
+    from e3nn import o3
+
+
+@registry.register_model("escn_export")
+class eSCN(nn.Module, GraphModelMixin):
+    """Equivariant Spherical Channel Network
+    Paper: Reducing SO(3) Convolutions to SO(2) for Efficient Equivariant GNNs
+
+
+    Args:
+        use_pbc (bool):         Use periodic boundary conditions
+        use_pbc_single (bool):  Process batch PBC graphs one at a time
+        regress_forces (bool):  Compute forces
+        otf_graph (bool):       Compute graph On The Fly (OTF)
+        max_neighbors (int):    Maximum number of neighbors per atom
+        cutoff (float):         Maximum distance between neighboring atoms in Angstroms
+        max_num_elements (int): Maximum atomic number
+
+        num_layers (int):             Number of layers in the GNN
+        lmax_list (int):              List of maximum degree of the spherical harmonics (1 to 10)
+        mmax_list (int):              List of maximum order of the spherical harmonics (0 to lmax)
+        sphere_channels (int):        Number of spherical channels (one set per resolution)
+        hidden_channels (int):        Number of hidden units in message passing
+        num_sphere_samples (int):     Number of samples used to approximate the integration of the sphere in the output blocks
+        edge_channels (int):          Number of channels for the edge invariant features
+        distance_function ("gaussian", "sigmoid", "linearsigmoid", "silu"):  Basis function used for distances
+        basis_width_scalar (float):   Width of distance basis function
+        distance_resolution (float):
Distance between distance basis functions in Angstroms + show_timing_info (bool): Show timing and memory info + """ + + def __init__( + self, + use_pbc: bool = True, + use_pbc_single: bool = False, + regress_forces: bool = True, + otf_graph: bool = False, + max_neighbors: int = 40, + cutoff: float = 8.0, + max_num_elements: int = 90, + num_layers: int = 8, + lmax_list: list[int] | None = None, + mmax_list: list[int] | None = None, + sphere_channels: int = 128, + hidden_channels: int = 256, + edge_channels: int = 128, + num_sphere_samples: int = 128, + distance_function: str = "gaussian", + basis_width_scalar: float = 1.0, + distance_resolution: float = 0.02, + show_timing_info: bool = False, + resolution: int | None = None, + ) -> None: + if mmax_list is None: + mmax_list = [2] + if lmax_list is None: + lmax_list = [6] + super().__init__() + + import sys + + if "e3nn" not in sys.modules: + logging.error("You need to install the e3nn library to use the SCN model") + raise ImportError + + self.regress_forces = regress_forces + self.use_pbc = use_pbc + self.use_pbc_single = use_pbc_single + self.cutoff = cutoff + self.otf_graph = otf_graph + self.show_timing_info = show_timing_info + self.max_num_elements = max_num_elements + self.hidden_channels = hidden_channels + self.num_layers = num_layers + self.num_atoms = 0 + self.num_sphere_samples = num_sphere_samples + self.sphere_channels = sphere_channels + self.max_neighbors = max_neighbors + self.edge_channels = edge_channels + self.distance_resolution = distance_resolution + self.grad_forces = False + self.lmax_list = lmax_list + self.mmax_list = mmax_list + self.num_resolutions: int = len(self.lmax_list) + self.sphere_channels_all: int = self.num_resolutions * self.sphere_channels + self.basis_width_scalar = basis_width_scalar + self.distance_function = distance_function + + # variables used for display purposes + self.counter = 0 + + # non-linear activation function used throughout the network + self.act = nn.SiLU() + + # Weights for message initialization + self.sphere_embedding = nn.Embedding( + self.max_num_elements, self.sphere_channels_all + ) + + # Initialize the function used to measure the distances between atoms + assert self.distance_function in [ + "gaussian", + "sigmoid", + "linearsigmoid", + "silu", + ] + self.num_gaussians = int(cutoff / self.distance_resolution) + if self.distance_function == "gaussian": + self.distance_expansion = GaussianSmearing( + 0.0, + cutoff, + self.num_gaussians, + basis_width_scalar, + ) + if self.distance_function == "sigmoid": + self.distance_expansion = SigmoidSmearing( + 0.0, + cutoff, + self.num_gaussians, + basis_width_scalar, + ) + if self.distance_function == "linearsigmoid": + self.distance_expansion = LinearSigmoidSmearing( + 0.0, + cutoff, + self.num_gaussians, + basis_width_scalar, + ) + if self.distance_function == "silu": + self.distance_expansion = SiLUSmearing( + 0.0, + cutoff, + self.num_gaussians, + basis_width_scalar, + ) + + # Initialize the transformations between spherical and grid representations + assert self.num_resolutions == 1, "Only one resolution is supported" + self.SO3_grid = nn.ModuleDict() + self.SO3_grid["lmax_lmax"] = SO3_Grid(self.lmax_list[0], self.lmax_list[0], resolution=resolution) + self.SO3_grid["lmax_mmax"] = SO3_Grid(self.lmax_list[0], self.mmax_list[0], resolution=resolution) + # self.SO3_grid = nn.ModuleList() + # for lval in range(max(self.lmax_list) + 1): + # SO3_m_grid = nn.ModuleList() + # for m in range(max(self.lmax_list) + 1): + # 
SO3_m_grid.append(SO3_Grid(lval, m, resolution=resolution))
+
+        #     self.SO3_grid.append(SO3_m_grid)
+        self.mappingReduced = CoefficientMapping(self.lmax_list, self.mmax_list)
+
+        # Initialize the blocks for each layer of the GNN
+        self.layer_blocks = nn.ModuleList()
+        for i in range(self.num_layers):
+            block = LayerBlock(
+                i,
+                self.sphere_channels,
+                self.hidden_channels,
+                self.edge_channels,
+                self.lmax_list,
+                self.mmax_list,
+                self.distance_expansion,
+                self.max_num_elements,
+                self.SO3_grid,
+                self.act,
+                self.mappingReduced
+            )
+            self.layer_blocks.append(block)
+
+        # Output blocks for energy and forces
+        self.energy_block = EnergyBlock(
+            self.sphere_channels_all, self.num_sphere_samples, self.act
+        )
+        if self.regress_forces:
+            self.force_block = ForceBlock(
+                self.sphere_channels_all, self.num_sphere_samples, self.act
+            )
+
+        # Create a roughly evenly distributed point sampling of the sphere for the output blocks
+        self.sphere_points = nn.Parameter(
+            CalcSpherePoints(self.num_sphere_samples), requires_grad=False
+        )
+
+        # For each spherical point, compute the spherical harmonic coefficient weights
+        sphharm_weights: list[nn.Parameter] = []
+        for i in range(self.num_resolutions):
+            sphharm_weights.append(
+                nn.Parameter(
+                    o3.spherical_harmonics(
+                        torch.arange(0, self.lmax_list[i] + 1).tolist(),
+                        self.sphere_points,
+                        False,
+                    ),
+                    requires_grad=False,
+                )
+            )
+        self.sphharm_weights = nn.ParameterList(sphharm_weights)
+
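+
+    # Note (illustrative, added for clarity; not in the upstream file): each
+    # entry of sphharm_weights holds the real spherical harmonics Y_lm
+    # evaluated at the sampled sphere points, shape
+    # [num_sphere_samples, (lmax + 1) ** 2]. forward() contracts these with the
+    # node coefficients via torch.einsum("abc, pb->apc", ...) to sample every
+    # channel on the sphere before the energy and force heads.
+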
+    @conditional_grad(torch.enable_grad())
+    def forward(self, data):
+        device = data.pos.device
+        self.batch_size = len(data.natoms)
+        self.dtype = data.pos.dtype
+
+        start_time = time.time()
+        atomic_numbers = data.atomic_numbers.long()
+        assert (
+            atomic_numbers.max().item() < self.max_num_elements
+        ), "Atomic number exceeds that given in model config"
+        num_atoms = len(atomic_numbers)
+        graph = self.generate_graph(data)
+
+        ###############################################################
+        # Initialize data structures
+        ###############################################################
+
+        # Compute 3x3 rotation matrix per edge
+        edge_rot_mat = self._init_edge_rot_mat(
+            data, graph.edge_index, graph.edge_distance_vec
+        )
+
+        # Initialize the WignerD matrices and other values for spherical harmonic calculations
+        self.SO3_edge_rot = SO3_Rotation(edge_rot_mat, self.lmax_list[0])
+
+        ###############################################################
+        # Initialize node embeddings
+        ###############################################################
+
+        # Init per node representations using an atomic number based embedding
+        offset = 0
+        x = SO3_Embedding(
+            num_atoms,
+            self.lmax_list,
+            self.sphere_channels,
+            device,
+            self.dtype,
+        )
+
+        offset_res = 0
+        offset = 0
+        # Initialize the l=0,m=0 coefficients for each resolution
+        for i in range(self.num_resolutions):
+            x.embedding[:, offset_res, :] = self.sphere_embedding(atomic_numbers)[
+                :, offset : offset + self.sphere_channels
+            ]
+            offset = offset + self.sphere_channels
+            offset_res = offset_res + int((self.lmax_list[i] + 1) ** 2)
+
+        ###############################################################
+        # Update spherical node embeddings
+        ###############################################################
+        x_message = x.embedding
+        for i in range(self.num_layers):
+            if i > 0:
+                x_message_new = self.layer_blocks[i](
+                    x_message,
+                    atomic_numbers,
+                    graph.edge_distance,
+                    graph.edge_index,
+                    self.SO3_edge_rot,
+                )
+
+                # Residual layer for all layers past the first
+                x_message = x_message + x_message_new
+
+            else:
+                # No residual for the first layer
+                x_message = self.layer_blocks[i](
+                    x_message,
+                    atomic_numbers,
+                    graph.edge_distance,
+                    graph.edge_index,
+                    self.SO3_edge_rot,
+                )
+        x.embedding = x_message
+
+        # Sample the spherical channels (node embeddings) at evenly distributed points on the sphere.
+        # These values are fed into the output blocks.
+        x_pt = torch.tensor([], device=device)
+        offset = 0
+        # Compute the embedding values at every sampled point on the sphere
+        for i in range(self.num_resolutions):
+            num_coefficients = int((x.lmax_list[i] + 1) ** 2)
+            x_pt = torch.cat(
+                [
+                    x_pt,
+                    torch.einsum(
+                        "abc, pb->apc",
+                        x.embedding[:, offset : offset + num_coefficients],
+                        self.sphharm_weights[i],
+                    ).contiguous(),
+                ],
+                dim=2,
+            )
+            offset = offset + num_coefficients
+
+        x_pt = x_pt.view(-1, self.sphere_channels_all)
+
+        ###############################################################
+        # Energy estimation
+        ###############################################################
+        node_energy = self.energy_block(x_pt)
+        energy = torch.zeros(len(data.natoms), device=device)
+        energy.index_add_(0, data.batch, node_energy.view(-1))
+        # Scale energy to help balance numerical precision w.r.t. forces
+        energy = energy * 0.001
+
+        outputs = {"energy": energy}
+        ###############################################################
+        # Force estimation
+        ###############################################################
+        if self.regress_forces:
+            forces = self.force_block(x_pt, self.sphere_points)
+            outputs["forces"] = forces
+
+        if self.show_timing_info is True:
+            torch.cuda.synchronize()
+            logging.info(
+                f"{self.counter} Time: {time.time() - start_time}\tMemory: {len(data.pos)}\t{torch.cuda.max_memory_allocated() / 1000000}"
+            )
+
+        self.counter = self.counter + 1
+
+        return outputs
+
+    # Initialize the edge rotation matrices
+    def _init_edge_rot_mat(self, data, edge_index, edge_distance_vec):
+        edge_vec_0 = edge_distance_vec
+        edge_vec_0_distance = torch.sqrt(torch.sum(edge_vec_0**2, dim=1))
+
+        # Make sure the atoms are far enough apart
+        if torch.min(edge_vec_0_distance) < 0.0001:
+            logging.error(
+                f"Error edge_vec_0_distance: {torch.min(edge_vec_0_distance)}"
+            )
+            (minval, minidx) = torch.min(edge_vec_0_distance, 0)
+            logging.error(
+                f"Error edge_vec_0_distance: {minidx} {edge_index[0, minidx]} {edge_index[1, minidx]} {data.pos[edge_index[0, minidx]]} {data.pos[edge_index[1, minidx]]}"
+            )
+
+        norm_x = edge_vec_0 / (edge_vec_0_distance.view(-1, 1))
+
+        edge_vec_2 = torch.rand_like(edge_vec_0) - 0.5
+        edge_vec_2 = edge_vec_2 / (
+            torch.sqrt(torch.sum(edge_vec_2**2, dim=1)).view(-1, 1)
+        )
+        # Create two rotated copies of the random vectors in case the random vector is aligned with norm_x
+        # With two 90 degree rotated vectors, at least one should not be aligned with norm_x
+        edge_vec_2b = edge_vec_2.clone()
+        edge_vec_2b[:, 0] = -edge_vec_2[:, 1]
+        edge_vec_2b[:, 1] = edge_vec_2[:, 0]
+        edge_vec_2c = edge_vec_2.clone()
+        edge_vec_2c[:, 1] = -edge_vec_2[:, 2]
+        edge_vec_2c[:, 2] = edge_vec_2[:, 1]
+        vec_dot_b = torch.abs(torch.sum(edge_vec_2b * norm_x, dim=1)).view(-1, 1)
+        vec_dot_c = torch.abs(torch.sum(edge_vec_2c * norm_x, dim=1)).view(-1, 1)
+
+        vec_dot = torch.abs(torch.sum(edge_vec_2 * norm_x, dim=1)).view(-1, 1)
+        edge_vec_2 = torch.where(torch.gt(vec_dot, vec_dot_b), edge_vec_2b, edge_vec_2)
+        vec_dot = torch.abs(torch.sum(edge_vec_2 * norm_x, dim=1)).view(-1, 1)
+        edge_vec_2 = torch.where(torch.gt(vec_dot, vec_dot_c), edge_vec_2c, edge_vec_2)
+
+        vec_dot = torch.abs(torch.sum(edge_vec_2 * norm_x, dim=1))
+        # Check the vectors aren't aligned
+        assert torch.max(vec_dot) < 0.99
+
+        norm_z = torch.cross(norm_x, edge_vec_2, dim=1)
+        norm_z = norm_z / (torch.sqrt(torch.sum(norm_z**2, dim=1, keepdim=True)))
+        norm_z = norm_z / (torch.sqrt(torch.sum(norm_z**2, dim=1)).view(-1, 1))
+        norm_y = torch.cross(norm_x, norm_z, dim=1)
+        norm_y = norm_y / (torch.sqrt(torch.sum(norm_y**2, dim=1, keepdim=True)))
+
+        # Construct the 3D rotation matrix
+        norm_x = norm_x.view(-1, 3, 1)
+        norm_y = -norm_y.view(-1, 3, 1)
+        norm_z = norm_z.view(-1, 3, 1)
+
+        edge_rot_mat_inv = torch.cat([norm_z, norm_x, norm_y], dim=2)
+        edge_rot_mat = torch.transpose(edge_rot_mat_inv, 1, 2)
+
+        return edge_rot_mat.detach()
+
+    @property
+    def num_params(self) -> int:
+        return sum(p.numel() for p in self.parameters())
+
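+# Note (illustrative, added for clarity; not in the upstream file): the matrix
+# returned by _init_edge_rot_mat is orthonormal with rows [norm_z, norm_x, -norm_y],
+# so its second row is the normalized edge direction. A quick sanity check,
+# assuming `model`, `data`, `edge_index`, and `edge_vec` are set up as in forward():
+#
+#     rot = model._init_edge_rot_mat(data, edge_index, edge_vec)
+#     unit = edge_vec / edge_vec.norm(dim=1, keepdim=True)
+#     assert torch.allclose((rot[:, 1, :] * unit).sum(dim=1),
+#                           torch.ones(len(unit)), atol=1e-5)
+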
+class LayerBlock(torch.nn.Module):
+    """
+    Layer block: Perform one layer (message passing and aggregation) of the GNN
+
+    Args:
+        layer_idx (int):            Layer number
+        sphere_channels (int):      Number of spherical channels
+        hidden_channels (int):      Number of hidden channels used during the SO(2) conv
+        edge_channels (int):        Size of invariant edge embedding
+        lmax_list (list:int):       List of degrees (l) for each resolution
+        mmax_list (list:int):       List of orders (m) for each resolution
+        distance_expansion (func):  Function used to compute distance embedding
+        max_num_elements (int):     Maximum number of atomic numbers
+        SO3_grid (SO3_grid):        Class used to convert from grid the spherical harmonic representations
+        act (function):             Non-linear activation function
+    """
+
+    def __init__(
+        self,
+        layer_idx: int,
+        sphere_channels: int,
+        hidden_channels: int,
+        edge_channels: int,
+        lmax_list: list[int],
+        mmax_list: list[int],
+        distance_expansion,
+        max_num_elements: int,
+        SO3_grid: SO3_Grid,
+        act,
+        mappingReduced,
+    ) -> None:
+        super().__init__()
+        self.layer_idx = layer_idx
+        self.act = act
+        self.lmax_list = lmax_list
+        self.mmax_list = mmax_list
+        self.num_resolutions = len(lmax_list)
+        self.sphere_channels = sphere_channels
+        self.sphere_channels_all = self.num_resolutions * self.sphere_channels
+        self.SO3_grid = SO3_grid
+        self.mappingReduced = mappingReduced
+
+        # Message block
+        self.message_block = MessageBlock(
+            self.layer_idx,
+            self.sphere_channels,
+            hidden_channels,
+            edge_channels,
+            self.lmax_list,
+            self.mmax_list,
+            distance_expansion,
+            max_num_elements,
+            self.SO3_grid,
+            self.act,
+            self.mappingReduced
+        )
+
+        # Non-linear point-wise convolution for the aggregated messages
+        self.fc1_sphere = nn.Linear(
+            2 * self.sphere_channels_all, self.sphere_channels_all, bias=False
+        )
+
+        self.fc2_sphere = nn.Linear(
+            self.sphere_channels_all, self.sphere_channels_all, bias=False
+        )
+
+        self.fc3_sphere = nn.Linear(
+            self.sphere_channels_all, self.sphere_channels_all, bias=False
+        )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        atomic_numbers: torch.Tensor,
+        edge_distance: torch.Tensor,
+        edge_index: torch.Tensor,
+        SO3_edge_rot: SO3_Rotation,
+    ) -> torch.Tensor:
+        # Compute messages by performing message block
+        x_message = self.message_block(
+            x,
+            atomic_numbers,
+            edge_distance,
+            edge_index,
+            SO3_edge_rot,
+        )
+        print(f"x_message: {x_message.mean()}")
+
+        # Compute point-wise spherical non-linearity on aggregated messages
+
+        # Project to grid
+        # x_grid_message = x_message.to_grid(self.SO3_grid["lmax_lmax"])
+        to_grid_mat = self.SO3_grid["lmax_lmax"].to_grid_mat[:, :, self.SO3_grid["lmax_lmax"].mapping.coefficient_idx(self.lmax_list[0], self.lmax_list[0])]
+
x_grid_message = torch.einsum("bai,zic->zbac", to_grid_mat, x_message) + + # x_grid = x.to_grid(self.SO3_grid["lmax_lmax"]) + x_grid = torch.einsum("bai,zic->zbac", to_grid_mat, x) + x_grid = torch.cat([x_grid, x_grid_message], dim=3) + + # Perform point-wise convolution + x_grid = self.act(self.fc1_sphere(x_grid)) + x_grid = self.act(self.fc2_sphere(x_grid)) + x_grid = self.fc3_sphere(x_grid) + print(f"x_grid: {x_grid.mean()}") + + # Project back to spherical harmonic coefficients + # x_message._from_grid(x_grid, self.SO3_grid["lmax_lmax"]) + from_grid_mat = self.SO3_grid["lmax_lmax"].from_grid_mat[:, :, self.SO3_grid["lmax_lmax"].mapping.coefficient_idx(self.lmax_list[0], self.lmax_list[0])] + x_message_final = torch.einsum("bai,zbac->zic", from_grid_mat, x_grid) + + print(f"x_message_final: {x_message_final.mean()}") + # Return aggregated messages + return x_message_final + + +class MessageBlock(torch.nn.Module): + """ + Message block: Perform message passing + + Args: + layer_idx (int): Layer number + sphere_channels (int): Number of spherical channels + hidden_channels (int): Number of hidden channels used during the SO(2) conv + edge_channels (int): Size of invariant edge embedding + lmax_list (list:int): List of degrees (l) for each resolution + mmax_list (list:int): List of orders (m) for each resolution + distance_expansion (func): Function used to compute distance embedding + max_num_elements (int): Maximum number of atomic numbers + SO3_grid (SO3_grid): Class used to convert from grid the spherical harmonic representations + act (function): Non-linear activation function + """ + + def __init__( + self, + layer_idx: int, + sphere_channels: int, + hidden_channels: int, + edge_channels: int, + lmax_list: list[int], + mmax_list: list[int], + distance_expansion, + max_num_elements: int, + SO3_grid: SO3_Grid, + act, + mappingReduced, + ) -> None: + super().__init__() + self.layer_idx = layer_idx + self.act = act + self.hidden_channels = hidden_channels + self.sphere_channels = sphere_channels + self.SO3_grid = SO3_grid + self.num_resolutions = len(lmax_list) + self.lmax_list = lmax_list + self.mmax_list = mmax_list + self.edge_channels = edge_channels + self.mappingReduced = mappingReduced + + # Create edge scalar (invariant to rotations) features + self.edge_block = EdgeBlock( + self.edge_channels, + distance_expansion, + max_num_elements, + self.act, + ) + + # Create SO(2) convolution blocks + self.so2_block_source = SO2Block( + self.sphere_channels, + self.hidden_channels, + self.edge_channels, + self.lmax_list, + self.mmax_list, + self.act, + self.mappingReduced + ) + self.so2_block_target = SO2Block( + self.sphere_channels, + self.hidden_channels, + self.edge_channels, + self.lmax_list, + self.mmax_list, + self.act, + self.mappingReduced + ) + + def forward( + self, + x: torch.Tensor, + atomic_numbers: torch.Tensor, + edge_distance: torch.Tensor, + edge_index: torch.Tensor, + SO3_edge_rot: SO3_Rotation, + ) -> torch.Tensor: + ############################################################### + # Compute messages + ############################################################### + # Compute edge scalar features (invariant to rotations) + # Uses atomic numbers and edge distance as inputs + x_edge = self.edge_block( + edge_distance, + atomic_numbers[edge_index[0]], # Source atom atomic number + atomic_numbers[edge_index[1]], # Target atom atomic number + ) + + # Copy embeddings for each edge's source and target nodes + x_source = x.clone() + x_target = x.clone() + x_source = 
x_source[edge_index[0, :]] + x_target = x_target[edge_index[1, :]] + + # Rotate the irreps to align with the edge + x_source = SO3_edge_rot.rotate(x_source, self.lmax_list[0], self.mmax_list[0]) + x_target = SO3_edge_rot.rotate(x_target, self.lmax_list[0], self.mmax_list[0]) + + # Compute messages + x_source = self.so2_block_source(x_source, x_edge) + x_target = self.so2_block_target(x_target, x_edge) + + # Add together the source and target results + x_target = x_source + x_target + + # Point-wise spherical non-linearity + to_grid_mat = self.SO3_grid["lmax_mmax"].get_to_grid_mat() + from_grid_mat = self.SO3_grid["lmax_mmax"].get_from_grid_mat() + x_grid = torch.einsum("bai,zic->zbac", to_grid_mat, x_target) + x_grid = self.act(x_grid) + x_target = torch.einsum("bai,zbac->zic", from_grid_mat, x_grid) + + # Rotate back the irreps + x_target = SO3_edge_rot.rotate_inv(x_target, self.lmax_list[0], self.mmax_list[0]) + + # Compute the sum of the incoming neighboring messages for each target node + new_embedding = torch.fill(x.clone(), 0) + new_embedding.index_add_(0, edge_index[1], x_target) + # x_target._reduce_edge(edge_index[1], len(x.embedding)) + + return new_embedding + + +class SO2Block(torch.nn.Module): + """ + SO(2) Block: Perform SO(2) convolutions for all m (orders) + + Args: + sphere_channels (int): Number of spherical channels + hidden_channels (int): Number of hidden channels used during the SO(2) conv + edge_channels (int): Size of invariant edge embedding + lmax_list (list:int): List of degrees (l) for each resolution + mmax_list (list:int): List of orders (m) for each resolution + act (function): Non-linear activation function + """ + + def __init__( + self, + sphere_channels: int, + hidden_channels: int, + edge_channels: int, + lmax_list: list[int], + mmax_list: list[int], + act, + mappingReduced + ) -> None: + super().__init__() + self.sphere_channels = sphere_channels + self.hidden_channels = hidden_channels + self.lmax_list = lmax_list + self.mmax_list = mmax_list + self.num_resolutions: int = len(lmax_list) + self.act = act + self.mappingReduced = mappingReduced + + num_channels_m0 = 0 + for i in range(self.num_resolutions): + num_coefficents = self.lmax_list[i] + 1 + num_channels_m0 = num_channels_m0 + num_coefficents * self.sphere_channels + + # SO(2) convolution for m=0 + self.fc1_dist0 = nn.Linear(edge_channels, self.hidden_channels) + self.fc1_m0 = nn.Linear(num_channels_m0, self.hidden_channels, bias=False) + self.fc2_m0 = nn.Linear(self.hidden_channels, num_channels_m0, bias=False) + + # SO(2) convolution for non-zero m + self.so2_conv = nn.ModuleList() + for m in range(1, max(self.mmax_list) + 1): + so2_conv = SO2Conv( + m, + self.sphere_channels, + self.hidden_channels, + edge_channels, + self.lmax_list, + self.mmax_list, + self.act, + ) + self.so2_conv.append(so2_conv) + + def forward( + self, + x: torch.Tensor, + x_edge: torch.Tensor, + ): + num_edges = len(x_edge) + + # Reshape the spherical harmonics based on m (order) + x = torch.einsum("nac,ba->nbc", x, self.mappingReduced.to_m) + + # Compute m=0 coefficients separately since they only have real values (no imaginary) + + # Compute edge scalar features for m=0 + x_edge_0 = self.act(self.fc1_dist0(x_edge)) + + x_0 = x[:, 0 : self.mappingReduced.m_size[0]].contiguous() + x_0 = x_0.view(num_edges, -1) + + x_0 = self.fc1_m0(x_0) + x_0 = x_0 * x_edge_0 + x_0 = self.fc2_m0(x_0) + x_0 = x_0.view(num_edges, -1, self.sphere_channels) + + # Update the m=0 coefficients + x[:, 0 : self.mappingReduced.m_size[0]] = x_0 + + 
# Compute the values for the m > 0 coefficients
+        offset = self.mappingReduced.m_size[0]
+        for m in range(1, max(self.mmax_list) + 1):
+            # Get the m order coefficients
+            x_m = x[
+                :, offset : offset + 2 * self.mappingReduced.m_size[m]
+            ].contiguous()
+            x_m = x_m.view(num_edges, 2, -1)
+            # Perform SO(2) convolution
+            x_m = self.so2_conv[m - 1](x_m, x_edge)
+            x_m = x_m.view(num_edges, -1, self.sphere_channels)
+            x[:, offset : offset + 2 * self.mappingReduced.m_size[m]] = x_m
+
+            offset = offset + 2 * self.mappingReduced.m_size[m]
+
+        # Reshape the spherical harmonics based on l (degree)
+        # x._l_primary(self.mappingReduced)
+        x = torch.einsum("nac,ab->nbc", x, self.mappingReduced.to_m)
+
+        return x
+
+
+class SO2Conv(torch.nn.Module):
+    """
+    SO(2) Conv: Perform an SO(2) convolution
+
+    Args:
+        m (int):                    Order of the spherical harmonic coefficients
+        sphere_channels (int):      Number of spherical channels
+        hidden_channels (int):      Number of hidden channels used during the SO(2) conv
+        edge_channels (int):        Size of invariant edge embedding
+        lmax_list (list:int):       List of degrees (l) for each resolution
+        mmax_list (list:int):       List of orders (m) for each resolution
+        act (function):             Non-linear activation function
+    """
+
+    def __init__(
+        self,
+        m: int,
+        sphere_channels: int,
+        hidden_channels: int,
+        edge_channels: int,
+        lmax_list: list[int],
+        mmax_list: list[int],
+        act,
+    ) -> None:
+        super().__init__()
+        self.hidden_channels = hidden_channels
+        self.lmax_list = lmax_list
+        self.mmax_list = mmax_list
+        self.sphere_channels = sphere_channels
+        self.num_resolutions: int = len(self.lmax_list)
+        self.m = m
+        self.act = act
+
+        num_channels = 0
+        for i in range(self.num_resolutions):
+            num_coefficents = 0
+            if self.mmax_list[i] >= m:
+                num_coefficents = self.lmax_list[i] - m + 1
+
+            num_channels = num_channels + num_coefficents * self.sphere_channels
+
+        assert num_channels > 0
+
+        # Embedding function of the distance
+        self.fc1_dist = nn.Linear(edge_channels, 2 * self.hidden_channels)
+
+        # Real weights of SO(2) convolution
+        self.fc1_r = nn.Linear(num_channels, self.hidden_channels, bias=False)
+        self.fc2_r = nn.Linear(self.hidden_channels, num_channels, bias=False)
+
+        # Imaginary weights of SO(2) convolution
+        self.fc1_i = nn.Linear(num_channels, self.hidden_channels, bias=False)
+        self.fc2_i = nn.Linear(self.hidden_channels, num_channels, bias=False)
+
+    def forward(self, x_m, x_edge) -> torch.Tensor:
+        # Compute edge scalar features
+        x_edge = self.act(self.fc1_dist(x_edge))
+        x_edge = x_edge.view(-1, 2, self.hidden_channels)
+
+        # Perform the complex weight multiplication
+        x_r = self.fc1_r(x_m)
+        x_r = x_r * x_edge[:, 0:1, :]
+        x_r = self.fc2_r(x_r)
+
+        x_i = self.fc1_i(x_m)
+        x_i = x_i * x_edge[:, 1:2, :]
+        x_i = self.fc2_i(x_i)
+
+        x_m_r = x_r[:, 0] - x_i[:, 1]
+        x_m_i = x_r[:, 1] + x_i[:, 0]
+
+        return torch.stack((x_m_r, x_m_i), dim=1).contiguous()
+
+
+class EdgeBlock(torch.nn.Module):
+    """
+    Edge Block: Compute invariant edge representation from edge distances and atomic numbers
+
+    Args:
+        edge_channels (int):        Size of invariant edge embedding
+        distance_expansion (func):  Function used to compute distance embedding
+        max_num_elements (int):     Maximum number of atomic numbers
+        act (function):             Non-linear activation function
+    """
+
+    def __init__(
+        self,
+        edge_channels,
+        distance_expansion,
+        max_num_elements,
+        act,
+    ) -> None:
+        super().__init__()
+        self.in_channels = distance_expansion.num_output
+        self.distance_expansion = distance_expansion
+        self.act = act
+
self.edge_channels = edge_channels + self.max_num_elements = max_num_elements + + # Embedding function of the distance + self.fc1_dist = nn.Linear(self.in_channels, self.edge_channels) + + # Embedding function of the atomic numbers + self.source_embedding = nn.Embedding(self.max_num_elements, self.edge_channels) + self.target_embedding = nn.Embedding(self.max_num_elements, self.edge_channels) + nn.init.uniform_(self.source_embedding.weight.data, -0.001, 0.001) + nn.init.uniform_(self.target_embedding.weight.data, -0.001, 0.001) + + # Embedding function of the edge + self.fc1_edge_attr = nn.Linear( + self.edge_channels, + self.edge_channels, + ) + + def forward(self, edge_distance, source_element, target_element): + # Compute distance embedding + x_dist = self.distance_expansion(edge_distance) + x_dist = self.fc1_dist(x_dist) + + # Compute atomic number embeddings + source_embedding = self.source_embedding(source_element) + target_embedding = self.target_embedding(target_element) + + # Compute invariant edge embedding + x_edge = self.act(source_embedding + target_embedding + x_dist) + return self.act(self.fc1_edge_attr(x_edge)) + + +class EnergyBlock(torch.nn.Module): + """ + Energy Block: Output block computing the energy + + Args: + num_channels (int): Number of channels + num_sphere_samples (int): Number of samples used to approximate the integral on the sphere + act (function): Non-linear activation function + """ + + def __init__( + self, + num_channels: int, + num_sphere_samples: int, + act, + ) -> None: + super().__init__() + self.num_channels = num_channels + self.num_sphere_samples = num_sphere_samples + self.act = act + + self.fc1 = nn.Linear(self.num_channels, self.num_channels) + self.fc2 = nn.Linear(self.num_channels, self.num_channels) + self.fc3 = nn.Linear(self.num_channels, 1, bias=False) + + def forward(self, x_pt) -> torch.Tensor: + # x_pt are the values of the channels sampled at different points on the sphere + x_pt = self.act(self.fc1(x_pt)) + x_pt = self.act(self.fc2(x_pt)) + x_pt = self.fc3(x_pt) + x_pt = x_pt.view(-1, self.num_sphere_samples, 1) + return torch.sum(x_pt, dim=1) / self.num_sphere_samples + + +class ForceBlock(torch.nn.Module): + """ + Force Block: Output block computing the per atom forces + + Args: + num_channels (int): Number of channels + num_sphere_samples (int): Number of samples used to approximate the integral on the sphere + act (function): Non-linear activation function + """ + + def __init__( + self, + num_channels: int, + num_sphere_samples: int, + act, + ) -> None: + super().__init__() + self.num_channels = num_channels + self.num_sphere_samples = num_sphere_samples + self.act = act + + self.fc1 = nn.Linear(self.num_channels, self.num_channels) + self.fc2 = nn.Linear(self.num_channels, self.num_channels) + self.fc3 = nn.Linear(self.num_channels, 1, bias=False) + + def forward(self, x_pt, sphere_points) -> torch.Tensor: + # x_pt are the values of the channels sampled at different points on the sphere + x_pt = self.act(self.fc1(x_pt)) + x_pt = self.act(self.fc2(x_pt)) + x_pt = self.fc3(x_pt) + x_pt = x_pt.view(-1, self.num_sphere_samples, 1) + forces = x_pt * sphere_points.view(1, self.num_sphere_samples, 3) + return torch.sum(forces, dim=1) / self.num_sphere_samples diff --git a/src/fairchem/core/models/escn/so3.py b/src/fairchem/core/models/escn/so3.py index 0c674e2888..34f505d51e 100644 --- a/src/fairchem/core/models/escn/so3.py +++ b/src/fairchem/core/models/escn/so3.py @@ -241,12 +241,12 @@ def _rotate(self, SO3_rotation, lmax_list, 
mmax_list) -> None:
         self.set_lmax_mmax(lmax_list.copy(), mmax_list.copy())
 
     # Rotate the embedding by the inverse of the rotation matrix
-    def _rotate_inv(self, SO3_rotation, res_size) -> None:
+    def _rotate_inv(self, SO3_rotation, mappingReduced) -> None:
         embedding_rotate = torch.tensor([], device=self.device, dtype=self.dtype)
 
         offset = 0
         for i in range(self.num_resolutions):
-            num_coefficients = res_size[i]
+            num_coefficients = mappingReduced.res_size[i]
             embedding_i = self.embedding[:, offset : offset + num_coefficients]
             embedding_rotate = torch.cat(
                 [
@@ -268,14 +268,18 @@ def _rotate_inv(self, SO3_rotation, res_size) -> None:
         self.set_lmax_mmax(self.lmax_list, self.mmax_list)
 
     # Compute point-wise spherical non-linearity
-    def _grid_act(self, SO3_grid, act, res_size) -> None:
+    def _grid_act(self, SO3_grid, act, mappingReduced) -> None:
         offset = 0
         for i in range(self.num_resolutions):
-            num_coefficients = res_size[i]
+            num_coefficients = mappingReduced.res_size[i]
 
             x_res = self.embedding[:, offset : offset + num_coefficients].contiguous()
-            to_grid_mat = SO3_grid.get_to_grid_mat(self.device)
-            from_grid_mat = SO3_grid.get_from_grid_mat(self.device)
+            to_grid_mat = SO3_grid[self.lmax_list[i]][
+                self.mmax_list[i]
+            ].get_to_grid_mat(self.device)
+            from_grid_mat = SO3_grid[self.lmax_list[i]][
+                self.mmax_list[i]
+            ].get_from_grid_mat(self.device)
 
             x_grid = torch.einsum("bai,zic->zbac", to_grid_mat, x_res)
             x_grid = act(x_grid)
@@ -289,8 +293,8 @@ def to_grid(self, SO3_grid, lmax: int = -1) -> torch.Tensor:
         if lmax == -1:
             lmax = max(self.lmax_list)
 
-        to_grid_mat_lmax = SO3_grid.get_to_grid_mat(self.device)
-        grid_mapping = SO3_grid.mapping
+        to_grid_mat_lmax = SO3_grid[lmax][lmax].get_to_grid_mat(self.device)
+        grid_mapping = SO3_grid[lmax][lmax].mapping
 
         offset = 0
         x_grid = torch.tensor([], device=self.device)
@@ -312,9 +316,12 @@ def to_grid(self, SO3_grid, lmax: int = -1) -> torch.Tensor:
         return x_grid
 
     # Compute irreps from grid representation
-    def _from_grid(self, x_grid, SO3_grid) -> None:
-        from_grid_mat_lmax = SO3_grid.get_from_grid_mat(self.device)
-        grid_mapping = SO3_grid.mapping
+    def _from_grid(self, x_grid, SO3_grid, lmax: int = -1) -> None:
+        if lmax == -1:
+            lmax = max(self.lmax_list)
+
+        from_grid_mat_lmax = SO3_grid[lmax][lmax].get_from_grid_mat(self.device)
+        grid_mapping = SO3_grid[lmax][lmax].mapping
 
         offset = 0
         offset_channel = 0
@@ -396,7 +403,7 @@ def RotationToWignerDMatrix(
         )
         gamma = torch.atan2(R[..., 0, 2], R[..., 0, 0])
 
-        size = int((end_lmax + 1) ** 2) - int((start_lmax) ** 2)
+        size = (end_lmax + 1) ** 2 - (start_lmax) ** 2
         wigner = torch.zeros(len(alpha), size, size, device=self.device)
         start = 0
         for lmax in range(start_lmax, end_lmax + 1):
diff --git a/src/fairchem/core/models/utils/Jd.pt b/src/fairchem/core/models/utils/Jd.pt
new file mode 100644
index 0000000000000000000000000000000000000000..01ed13e3f1c78bfcdf09077adc30aa3ab4a64685
GIT binary patch
literal 21697 [base85-encoded binary payload omitted: precomputed Wigner-D "J" matrices]
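For context, a hedged sketch of how the new Jd.pt asset is consumed by wigner_D in the so3_utils module below; the load path and the loading code itself are assumptions based on this patch, not verified against the repo:

    import torch

    # _Jd is expected to be a sequence of square matrices, one per degree l,
    # with _Jd[l] of shape (2l + 1, 2l + 1); wigner_D indexes it as _Jd[lv].
    _Jd = torch.load("src/fairchem/core/models/utils/Jd.pt")  # assumed path
    for l, J in enumerate(_Jd):
        assert J.shape == (2 * l + 1, 2 * l + 1)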
diff --git a/src/fairchem/core/models/utils/so3_utils.py b/src/fairchem/core/models/utils/so3_utils.py
new file mode 100644
+def wigner_D(
+    lv: int, alpha: torch.Tensor, beta: torch.Tensor, gamma: torch.Tensor
+) -> torch.Tensor:
+    if not lv < len(_Jd):
+        raise NotImplementedError(
+            f"wigner D maximum l implemented is {len(_Jd) - 1}, send us an email to ask for more"
+        )
+
+    alpha, beta, gamma = torch.broadcast_tensors(alpha, beta, gamma)
+    J = _Jd[lv].to(dtype=alpha.dtype, device=alpha.device)
+    Xa = _z_rot_mat(alpha, lv)
+    Xb = _z_rot_mat(beta, lv)
+    Xc = _z_rot_mat(gamma, lv)
+    return Xa @ J @ Xb @ J @ Xc
+
+
+def _z_rot_mat(angle: torch.Tensor, lv: int) -> torch.Tensor:
+    shape, device, dtype = angle.shape, angle.device, angle.dtype
+    M = angle.new_zeros((*shape, 2 * lv + 1, 2 * lv + 1))
+    inds = torch.arange(0, 2 * lv + 1, 1, device=device)
+    reversed_inds = torch.arange(2 * lv, -1, -1, device=device)
+    frequencies = torch.arange(lv, -lv - 1, -1, dtype=dtype, device=device)
+    M[..., inds, reversed_inds] = torch.sin(frequencies * angle[..., None])
+    M[..., inds, inds] = torch.cos(frequencies * angle[..., None])
+    return M
+
+
+class CoefficientMapping(torch.nn.Module):
+    """
+    Helper module for coefficients used to reshape l <--> m and to get coefficients of specific degree or order
+
+    Args:
+        lmax_list (list:int): List of maximum degree of the spherical harmonics
+        mmax_list (list:int): List of maximum order of the spherical harmonics
+        use_rotate_inv_rescale (bool): Whether to pre-compute inverse rotation rescale matrices
+    """
+
+    def __init__(
+        self,
+        lmax_list,
+        mmax_list,
+        use_rotate_inv_rescale=False
+    ):
+        super().__init__()
+
+        self.lmax_list = lmax_list
+        self.mmax_list = mmax_list
+        self.use_rotate_inv_rescale = use_rotate_inv_rescale
+        self.num_resolutions = len(lmax_list)
+
+        assert (len(self.lmax_list) == 1) and (len(self.mmax_list) == 1)
+
+        # Compute the degree (l) and order (m) for each entry of the embedding
+        l_harmonic = torch.tensor([]).long()
+        m_harmonic = torch.tensor([]).long()
+        m_complex = torch.tensor([]).long()
+
+        self.res_size = torch.zeros([self.num_resolutions]).long().tolist()
+
+        offset = 0
+        for i in range(self.num_resolutions):
+            for l in range(0, self.lmax_list[i] + 1):
+                mmax = min(self.mmax_list[i], l)
+                m = torch.arange(-mmax, mmax + 1).long()
+                m_complex = torch.cat([m_complex, m], dim=0)
+                m_harmonic = torch.cat(
+                    [m_harmonic, torch.abs(m).long()], dim=0
+                )
+                l_harmonic = torch.cat(
+                    [l_harmonic, m.fill_(l).long()], dim=0
+                )
+            self.res_size[i] = len(l_harmonic) - offset
+            offset = len(l_harmonic)
+
+        num_coefficients = len(l_harmonic)
+        # `self.to_m` moves m components from different L to contiguous index
+        to_m = torch.zeros([num_coefficients, num_coefficients])
+        self.m_size = torch.zeros([max(self.mmax_list) + 1]).long().tolist()
+
+        offset = 0
+        for m in range(max(self.mmax_list) + 1):
+            idx_r, idx_i = self.complex_idx(m, -1, m_complex, l_harmonic)
+
+            for idx_out, idx_in in enumerate(idx_r):
+                to_m[idx_out + offset, idx_in] = 1.0
+            offset = offset + len(idx_r)
+
+            self.m_size[m] = int(len(idx_r))
+
+            for idx_out, idx_in in enumerate(idx_i):
+                to_m[idx_out + offset, idx_in] = 1.0
+            offset = offset + len(idx_i)
+
+        to_m = to_m.detach()
+
+        # save tensors and they will be
moved to GPU + self.register_buffer('l_harmonic', l_harmonic) + self.register_buffer('m_harmonic', m_harmonic) + self.register_buffer('m_complex', m_complex) + # self.register_buffer('res_size', res_size) + self.register_buffer('to_m', to_m) + # self.register_buffer('m_size', m_size) + + self.pre_compute_coefficient_idx() + if self.use_rotate_inv_rescale: + self.pre_compute_rotate_inv_rescale() + + + # Return mask containing coefficients of order m (real and imaginary parts) + def complex_idx(self, m, lmax, m_complex, l_harmonic): + ''' + Add `m_complex` and `l_harmonic` to the input arguments + since we cannot use `self.m_complex`. + ''' + if lmax == -1: + lmax = max(self.lmax_list) + + indices = torch.arange(len(l_harmonic)) + # Real part + mask_r = torch.bitwise_and( + l_harmonic.le(lmax), m_complex.eq(m) + ) + mask_idx_r = torch.masked_select(indices, mask_r) + + mask_idx_i = torch.tensor([]).long() + # Imaginary part + if m != 0: + mask_i = torch.bitwise_and( + l_harmonic.le(lmax), m_complex.eq(-m) + ) + mask_idx_i = torch.masked_select(indices, mask_i) + + return mask_idx_r, mask_idx_i + + + def pre_compute_coefficient_idx(self): + ''' + Pre-compute the results of `coefficient_idx()` and access them with `prepare_coefficient_idx()` + ''' + lmax = max(self.lmax_list) + for l in range(lmax + 1): + for m in range(lmax + 1): + mask = torch.bitwise_and( + self.l_harmonic.le(l), self.m_harmonic.le(m) + ) + indices = torch.arange(len(mask)) + mask_indices = torch.masked_select(indices, mask) + self.register_buffer('coefficient_idx_l{}_m{}'.format(l, m), mask_indices) + + + def prepare_coefficient_idx(self): + ''' + Construct a list of buffers + ''' + lmax = max(self.lmax_list) + coefficient_idx_list = [] + for l in range(lmax + 1): + l_list = [] + for m in range(lmax + 1): + l_list.append(getattr(self, 'coefficient_idx_l{}_m{}'.format(l, m), None)) + coefficient_idx_list.append(l_list) + return coefficient_idx_list + + + # Return mask containing coefficients less than or equal to degree (l) and order (m) + def coefficient_idx(self, lmax: int, mmax: int): + if lmax > max(self.lmax_list) or mmax > max(self.lmax_list): + mask = torch.bitwise_and( + self.l_harmonic.le(lmax), self.m_harmonic.le(mmax) + ) + indices = torch.arange(len(mask), device=mask.device) + mask_indices = torch.masked_select(indices, mask) + return mask_indices + else: + temp = self.prepare_coefficient_idx() + return temp[lmax][mmax] + + + def pre_compute_rotate_inv_rescale(self): + lmax = max(self.lmax_list) + for l in range(lmax + 1): + for m in range(lmax + 1): + mask_indices = self.coefficient_idx(l, m) + rotate_inv_rescale = torch.ones((1, int((l + 1)**2), int((l + 1)**2))) + for l_sub in range(l + 1): + if l_sub <= m: + continue + start_idx = l_sub ** 2 + length = 2 * l_sub + 1 + rescale_factor = math.sqrt(length / (2 * m + 1)) + rotate_inv_rescale[:, start_idx : (start_idx + length), start_idx : (start_idx + length)] = rescale_factor + rotate_inv_rescale = rotate_inv_rescale[:, :, mask_indices] + self.register_buffer('rotate_inv_rescale_l{}_m{}'.format(l, m), rotate_inv_rescale) + + + def prepare_rotate_inv_rescale(self): + lmax = max(self.lmax_list) + rotate_inv_rescale_list = [] + for l in range(lmax + 1): + l_list = [] + for m in range(lmax + 1): + l_list.append(getattr(self, 'rotate_inv_rescale_l{}_m{}'.format(l, m), None)) + rotate_inv_rescale_list.append(l_list) + return rotate_inv_rescale_list + + + # Return the re-scaling for rotating back to original frame + # this is required since we only use a subset of m 
components for SO(2) convolution + def get_rotate_inv_rescale(self, lmax, mmax): + temp = self.prepare_rotate_inv_rescale() + return temp[lmax][mmax] + + + def __repr__(self): + return f"{self.__class__.__name__}(lmax_list={self.lmax_list}, mmax_list={self.mmax_list})" + +class SO3_Rotation(torch.nn.Module): + """ + Helper functions for Wigner-D rotations + + Args: + rot_mat3x3 (tensor): Rotation matrix + lmax_list (list:int): List of maximum degree of the spherical harmonics + """ + + def __init__( + self, + rot_mat3x3: torch.Tensor, + lmax: list[int], + ) -> None: + super().__init__() + self.device = rot_mat3x3.device + self.dtype = rot_mat3x3.dtype + + self.wigner = self.RotationToWignerDMatrix(rot_mat3x3, 0, lmax) + self.wigner_inv = torch.transpose(self.wigner, 1, 2).contiguous() + + self.wigner = self.wigner.detach() + self.wigner_inv = self.wigner_inv.detach() + + self.set_lmax(lmax) + + # Initialize coefficients for reshape l<-->m + def set_lmax(self, lmax) -> None: + self.lmax = lmax + self.mapping = CoefficientMapping([self.lmax], [self.lmax], use_rotate_inv_rescale=True) + + # Rotate the embedding + def rotate(self, embedding, out_lmax, out_mmax) -> torch.Tensor: + out_mask = self.mapping.coefficient_idx(out_lmax, out_mmax) + wigner = self.wigner[:, out_mask, :] + return torch.bmm(wigner, embedding) + + # Rotate the embedding by the inverse of the rotation matrix + def rotate_inv(self, embedding, in_lmax, in_mmax) -> torch.Tensor: + in_mask = self.mapping.coefficient_idx(in_lmax, in_mmax) + wigner_inv = self.wigner_inv[:, :, in_mask] + + return torch.bmm(wigner_inv, embedding) + + # Compute Wigner matrices from rotation matrix + def RotationToWignerDMatrix( + self, edge_rot_mat: torch.Tensor, start_lmax: int, end_lmax: int + ) -> torch.Tensor: + x = edge_rot_mat @ edge_rot_mat.new_tensor([0.0, 1.0, 0.0]) + alpha, beta = o3.xyz_to_angles(x) + R = ( + o3.angles_to_matrix(alpha, beta, torch.zeros_like(alpha)).transpose(-1, -2) + @ edge_rot_mat + ) + gamma = torch.atan2(R[..., 0, 2], R[..., 0, 0]) + + size = int((end_lmax + 1) ** 2) - int((start_lmax) ** 2) + wigner = torch.zeros(len(alpha), size, size, device=self.device) + start = 0 + for lmax in range(start_lmax, end_lmax + 1): + block = wigner_D(lmax, alpha, beta, gamma) + end = start + block.size()[1] + wigner[:, start:end, start:end] = block + start = end + + return wigner.detach() + +class SO3_Grid(torch.nn.Module): + """ + Helper functions for grid representation of the irreps + + Args: + lmax (int): Maximum degree of the spherical harmonics + mmax (int): Maximum order of the spherical harmonics + """ + + def __init__( + self, + lmax: int, + mmax: int, + normalization: str = "integral", + resolution: int | None = None, + ): + super().__init__() + + self.lmax = lmax + self.mmax = mmax + self.lat_resolution = 2 * (self.lmax + 1) + if lmax == mmax: + self.long_resolution = 2 * (self.mmax + 1) + 1 + else: + self.long_resolution = 2 * (self.mmax) + 1 + if resolution is not None: + self.lat_resolution = resolution + self.long_resolution = resolution + + self.mapping = CoefficientMapping([self.lmax], [self.lmax]) + + device = "cpu" + + to_grid = ToS2Grid( + self.lmax, + (self.lat_resolution, self.long_resolution), + normalization=normalization, # normalization="integral", + device=device, + ) + to_grid_mat = torch.einsum("mbi, am -> bai", to_grid.shb, to_grid.sha).detach() + # rescale based on mmax + if lmax != mmax: + for lval in range(lmax + 1): + if lval <= mmax: + continue + start_idx = lval**2 + length = 2 * lval + 1 + 
rescale_factor = math.sqrt(length / (2 * mmax + 1)) + to_grid_mat[:, :, start_idx : (start_idx + length)] = ( + to_grid_mat[:, :, start_idx : (start_idx + length)] * rescale_factor + ) + to_grid_mat = to_grid_mat[ + :, :, self.mapping.coefficient_idx(self.lmax, self.mmax) + ] + + from_grid = FromS2Grid( + (self.lat_resolution, self.long_resolution), + self.lmax, + normalization=normalization, # normalization="integral", + device=device, + ) + from_grid_mat = torch.einsum( + "am, mbi -> bai", from_grid.sha, from_grid.shb + ).detach() + # rescale based on mmax + if lmax != mmax: + for lval in range(lmax + 1): + if lval <= mmax: + continue + start_idx = lval**2 + length = 2 * lval + 1 + rescale_factor = math.sqrt(length / (2 * mmax + 1)) + from_grid_mat[:, :, start_idx : (start_idx + length)] = ( + from_grid_mat[:, :, start_idx : (start_idx + length)] + * rescale_factor + ) + from_grid_mat = from_grid_mat[ + :, :, self.mapping.coefficient_idx(self.lmax, self.mmax) + ] + + # save tensors and they will be moved to GPU + self.register_buffer("to_grid_mat", to_grid_mat) + self.register_buffer("from_grid_mat", from_grid_mat) + + # Compute matrices to transform irreps to grid + def get_to_grid_mat(self, device=None): + return self.to_grid_mat + + # Compute matrices to transform grid to irreps + def get_from_grid_mat(self,device=None): + return self.from_grid_mat + + # Compute grid from irreps representation + def to_grid(self, embedding, lmax: int, mmax: int): + to_grid_mat = self.to_grid_mat[:, :, self.mapping.coefficient_idx(lmax, mmax)] + return torch.einsum("bai, zic -> zbac", to_grid_mat, embedding) + + # Compute irreps from grid representation + def from_grid(self, grid, lmax: int, mmax: int): + from_grid_mat = self.from_grid_mat[ + :, :, self.mapping.coefficient_idx(lmax, mmax) + ] + return torch.einsum("bai, zbac -> zic", from_grid_mat, grid) diff --git a/tests/core/models/test_escn_compiles.py b/tests/core/models/test_escn_compiles.py index 817a5e674b..b96e32e152 100644 --- a/tests/core/models/test_escn_compiles.py +++ b/tests/core/models/test_escn_compiles.py @@ -28,7 +28,7 @@ from fairchem.core.common.transforms import RandomRotate from fairchem.core.models.utils.so3_utils import CoefficientMapping -from fairchem.core.models.escn.escn import SO2Block +from fairchem.core.models.escn import escn_exportable from torch.export import export from torch.export import Dim @@ -52,10 +52,10 @@ def load_data(): return data_list[0] -def load_model(): +def load_model(name: str): torch.manual_seed(4) setup_imports() - model = registry.get_model_class("escn")( + model = registry.get_model_class(name)( use_pbc = True, use_pbc_single = False, regress_forces = True, @@ -78,16 +78,6 @@ def load_model(): ) return model -def expected_energy_forces(): - energy = torch.tensor([0.0001747000]) - forces = torch.tensor([-1.2720219900e-07, 8.2126695133e-07, -3.8776403244e-07]) - return energy, forces - -def expected_energy_forces_cuda(): - energy = torch.tensor([0.0001747000]) - forces = torch.tensor([-4.5273015559e-08, 9.0246174977e-07, -3.8560736471e-07]) - return energy, forces - def init(backend: str): if not torch.distributed.is_initialized(): init_local_distributed_process_group(backend=backend) @@ -97,26 +87,26 @@ def test_escn_baseline_cpu(self, tol=1e-5): init('gloo') data = load_data() data = data_list_collater([data]) - model = load_model() - ddp_model = DistributedDataParallel(model) - output = ddp_model(data) - expected_energy, expected_forces = expected_energy_forces() + base_model = 
DistributedDataParallel(load_model("escn")) + export_model = DistributedDataParallel(load_model("escn_export")) + base_output = base_model(data) + export_output = export_model(data) torch.set_printoptions(precision=8) - assert torch.allclose(output["energy"], expected_energy, atol=tol) - assert torch.allclose(output["forces"].mean(0), expected_forces, atol=tol) + assert torch.allclose(base_output["energy"], export_output["energy"], atol=tol) + assert torch.allclose(base_output["forces"].mean(0), export_output["forces"].mean(0), atol=tol) @skip_if_no_cuda def test_escn_baseline_cuda(self, tol=1e-5): init('nccl') data = load_data() data = data_list_collater([data]).to("cuda") - model = load_model().cuda() - ddp_model = DistributedDataParallel(model) - output = ddp_model(data) - expected_energy, expected_forces = expected_energy_forces_cuda() + base_model = DistributedDataParallel(load_model("escn")).cuda() + export_model = DistributedDataParallel(load_model("escn_export")).cuda() + base_output = base_model(data) + export_output = export_model(data) torch.set_printoptions(precision=8) - assert torch.allclose(output["energy"].cpu(), expected_energy, atol=tol) - assert torch.allclose(output["forces"].mean(0).cpu(), expected_forces, atol=tol) + assert torch.allclose(base_output["energy"], export_output["energy"], atol=tol) + assert torch.allclose(base_output["forces"].mean(0), export_output["forces"].mean(0), atol=tol) def test_escn_compiles(self): init("gloo") @@ -159,7 +149,7 @@ def test_rotation_invariance(self) -> None: # Pass it through the model. batch = data_list_collater([data, data_rotated]) - model = load_model() + model = load_model("escn_export") model.eval() out = model(batch) @@ -190,12 +180,20 @@ def test_escn_so2_conv_exports(self) -> None: } lmax, mmax = 4, 2 - mappingReduced = CoefficientMapping([lmax], [mmax]) + mappingReduced = escn_exportable.CoefficientMapping([lmax], [mmax]) shpere_channels = 128 edge_channels = 128 args=(torch.rand(680, 19, shpere_channels), torch.rand(680, edge_channels)) - so2 = SO2Block(sphere_channels=shpere_channels, hidden_channels=128, edge_channels=edge_channels, lmax_list=[lmax], mmax_list=[mmax], act=torch.nn.SiLU(), mappingReduced=mappingReduced) + so2 = escn_exportable.SO2Block( + sphere_channels=shpere_channels, + hidden_channels=128, + edge_channels=edge_channels, + lmax_list=[lmax], + mmax_list=[mmax], + act=torch.nn.SiLU(), + mappingReduced=mappingReduced + ) prog = export(so2, args=args, dynamic_shapes=dynamic_shapes1) export_out = prog.module()(*args) regular_out = so2(*args) From 7f901529a83962d27846cbcc4d921a74ec75b08e Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Thu, 29 Aug 2024 23:07:34 -0700 Subject: [PATCH 11/25] message block fails export due to SO3Rotation input --- tests/core/models/test_escn_compiles.py | 127 ++++++++++++++++++------ 1 file changed, 94 insertions(+), 33 deletions(-) diff --git a/tests/core/models/test_escn_compiles.py b/tests/core/models/test_escn_compiles.py index b96e32e152..557d95cc74 100644 --- a/tests/core/models/test_escn_compiles.py +++ b/tests/core/models/test_escn_compiles.py @@ -26,8 +26,10 @@ from fairchem.core.datasets import data_list_collater from fairchem.core.preprocessing import AtomsToGraphs from fairchem.core.common.transforms import RandomRotate +from fairchem.core.models.scn.smearing import GaussianSmearing +from fairchem.core.models.base import GraphModelMixin -from fairchem.core.models.utils.so3_utils import CoefficientMapping +from fairchem.core.models.utils.so3_utils import 
CoefficientMapping, SO3_Grid, SO3_Rotation from fairchem.core.models.escn import escn_exportable from torch.export import export @@ -108,36 +110,6 @@ def test_escn_baseline_cuda(self, tol=1e-5): assert torch.allclose(base_output["energy"], export_output["energy"], atol=tol) assert torch.allclose(base_output["forces"].mean(0), export_output["forces"].mean(0), atol=tol) - def test_escn_compiles(self): - init("gloo") - data = load_data() - data = data_list_collater([data]) - model = load_model() - ddp_model = DistributedDataParallel(model) - - torch._dynamo.config.optimize_ddp = False - torch._dynamo.config.assume_static_by_default = False - torch._dynamo.config.automatic_dynamic_shapes = True - # torch._dynamo.config.suppress_errors = True - - # os.environ["TORCH_LOGS"] = "+dynamo,recompiles" - torch._logging.set_logs(dynamo = logging.INFO) - # os.environ["TORCHDYNAMO_VERBOSE"] = "1" - # os.environ["TORCHDYNAMO_REPRO_AFTER"]="dynamo" - # torch._dynamo.config.verbose = True - compiled_model = torch.compile(model, dynamic=True) - torch._dynamo.config.optimize_ddp = False - # torch._dynamo.explain(model)(data) - # assert False - # torch._dynamo.reset() - # explain_output = torch._dynamo.explain(model)(data) - # print(explain_output) - - output = compiled_model(data) - # expected_energy, expected_forces = expected_energy_forces() - # assert torch.allclose(output["energy"], expected_energy) - # assert torch.allclose(output["forces"].mean(0), expected_forces) - def test_rotation_invariance(self) -> None: random.seed(1) data = load_data() @@ -165,7 +137,7 @@ def test_rotation_invariance(self) -> None: decimal=5, ) - def test_escn_so2_conv_exports(self) -> None: + def test_escn_so2_conv_exports_and_compiles(self, tol=1e-5) -> None: torch._dynamo.config.assume_static_by_default = False torch._dynamo.config.automatic_dynamic_shapes = True inp1_dim0 = Dim("inp1_dim0") @@ -197,4 +169,93 @@ def test_escn_so2_conv_exports(self) -> None: prog = export(so2, args=args, dynamic_shapes=dynamic_shapes1) export_out = prog.module()(*args) regular_out = so2(*args) - assert torch.allclose(export_out, regular_out) + assert torch.allclose(export_out, regular_out, atol=tol) + + compiled_model = torch.compile(so2, dynamic=True) + compiled_out = compiled_model(*args) + assert torch.allclose(compiled_out, regular_out, atol=tol) + + def test_escn_message_block_exports_and_compiles(self) -> None: + random.seed(1) + + sphere_channels = 128 + hidden_channels = 128 + edge_channels = 128 + lmax, mmax = 4, 2 + distance_expansion = GaussianSmearing(0.0, 8.0, int(8.0 / 0.02), 1.0) + SO3_grid = torch.nn.ModuleDict() + SO3_grid["lmax_lmax"] = SO3_Grid(lmax, lmax) + SO3_grid["lmax_mmax"] = SO3_Grid(lmax, mmax) + mappingReduced = escn_exportable.CoefficientMapping([lmax], [mmax]) + message_block = escn_exportable.MessageBlock( + layer_idx = 0, + sphere_channels = sphere_channels, + hidden_channels = hidden_channels, + edge_channels = edge_channels, + lmax_list = [lmax], + mmax_list = [mmax], + distance_expansion = distance_expansion, + max_num_elements = 90, + SO3_grid = SO3_grid, + act = torch.nn.SiLU(), + mappingReduced = mappingReduced + ) + + data = load_data() + data = data_list_collater([data]) + full_model = load_model("escn_export") + graph = full_model.generate_graph(data) + edge_rot_mat = full_model._init_edge_rot_mat( + data, graph.edge_index, graph.edge_distance_vec + ) + SO3_edge_rot = SO3_Rotation(edge_rot_mat, lmax) + + # generate inputs + batch_sizes = [34] + num_coefs = 25 + num_edges = 680 + args = [] + for b in 
batch_sizes: + x = torch.rand([b, num_coefs, sphere_channels]) + atom_n = torch.randint(1, 90, (b,)) + edge_d = torch.rand([num_edges]) + edge_indx = torch.randint(0, b, (2, num_edges)) + args.append((x, atom_n, edge_d, edge_indx, SO3_edge_rot)) + + # compiled_model = torch.compile(message_block, dynamic=True) + exported_prog = export(message_block, args=args[0]) + exported_output = exported_prog(*args[0]) + + # output = message_block(*args) + # compiled_output = compiled_model(*args) + + + def test_escn_compiles(self): + init("gloo") + data = load_data() + data = data_list_collater([data]) + model = load_model('escn_export') + ddp_model = DistributedDataParallel(model) + + torch._dynamo.config.optimize_ddp = False + torch._dynamo.config.assume_static_by_default = False + torch._dynamo.config.automatic_dynamic_shapes = True + # torch._dynamo.config.suppress_errors = True + + # os.environ["TORCH_LOGS"] = "+dynamo,recompiles" + # torch._logging.set_logs(dynamo = logging.INFO) + # os.environ["TORCHDYNAMO_VERBOSE"] = "1" + # os.environ["TORCHDYNAMO_REPRO_AFTER"]="dynamo" + # torch._dynamo.config.verbose = True + compiled_model = torch.compile(model, dynamic=True) + torch._dynamo.config.optimize_ddp = False + # torch._dynamo.explain(model)(data) + # assert False + # torch._dynamo.reset() + # explain_output = torch._dynamo.explain(model)(data) + # print(explain_output) + + output = compiled_model(data) + # expected_energy, expected_forces = expected_energy_forces() + # assert torch.allclose(output["energy"], expected_energy) + # assert torch.allclose(output["forces"].mean(0), expected_forces) From f56b4407ec5bd996e45541d91b9c5969a9fa99aa Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Sun, 1 Sep 2024 20:16:52 -0700 Subject: [PATCH 12/25] message block compiles and exports --- .../core/models/escn/escn_exportable.py | 34 +++++------ src/fairchem/core/models/utils/so3_utils.py | 59 +++++++++---------- tests/core/models/test_escn_compiles.py | 28 ++++++--- 3 files changed, 61 insertions(+), 60 deletions(-) diff --git a/src/fairchem/core/models/escn/escn_exportable.py b/src/fairchem/core/models/escn/escn_exportable.py index 1f01fadd59..b13c8e5cd7 100644 --- a/src/fairchem/core/models/escn/escn_exportable.py +++ b/src/fairchem/core/models/escn/escn_exportable.py @@ -25,6 +25,7 @@ CoefficientMapping, SO3_Grid, SO3_Rotation, + rotation_to_wigner ) from fairchem.core.models.escn.so3 import SO3_Embedding from fairchem.core.models.scn.sampling import CalcSpherePoints @@ -177,14 +178,6 @@ def __init__( self.SO3_grid = nn.ModuleDict() self.SO3_grid["lmax_lmax"] = SO3_Grid(self.lmax_list[0], self.lmax_list[0], resolution=resolution) self.SO3_grid["lmax_mmax"] = SO3_Grid(self.lmax_list[0], self.mmax_list[0], resolution=resolution) - # self.SO3_grid = nn.ModuleList() - # for lval in range(max(self.lmax_list) + 1): - # SO3_m_grid = nn.ModuleList() - # for m in range(max(self.lmax_list) + 1): - # SO3_m_grid.append(SO3_Grid(lval, m, resolution=resolution)) - - # self.SO3_grid.append(SO3_m_grid) - # import pdb;pdb.set_trace() self.mappingReduced = CoefficientMapping(self.lmax_list, self.mmax_list) # Initialize the blocks for each layer of the GNN @@ -257,9 +250,7 @@ def forward(self, data): edge_rot_mat = self._init_edge_rot_mat( data, graph.edge_index, graph.edge_distance_vec ) - - # Initialize the WignerD matrices and other values for spherical harmonic calculations - self.SO3_edge_rot = SO3_Rotation(edge_rot_mat, self.lmax_list[0]) + wigner = rotation_to_wigner(edge_rot_mat, 0, self.lmax_list[0]).detach() 
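A hedged aside on shapes, based on rotation_to_wigner as defined later in this patch: for E edges and lmax L, the result is an (E, (L + 1)^2, (L + 1)^2) block-diagonal tensor holding one (2l + 1) x (2l + 1) Wigner-D block per degree l. A standalone sketch of the layout; the sizes are illustrative only:

    import torch

    E, L = 680, 4                    # e.g. 680 edges, lmax = 4
    size = (L + 1) ** 2              # 25 spherical harmonic coefficients
    wigner = torch.zeros(E, size, size)
    start = 0
    for l in range(L + 1):           # same per-degree tiling as rotation_to_wigner
        end = start + (2 * l + 1)
        # wigner[:, start:end, start:end] holds the degree-l Wigner-D block
        start = end
    assert start == size             # the blocks exactly tile the matrix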
############################################################### # Initialize node embeddings @@ -296,7 +287,7 @@ def forward(self, data): atomic_numbers, graph.edge_distance, graph.edge_index, - self.SO3_edge_rot, + wigner, ) # Residual layer for all layers past the first @@ -309,7 +300,7 @@ def forward(self, data): atomic_numbers, graph.edge_distance, graph.edge_index, - self.SO3_edge_rot, + wigner, ) x.embedding = x_message @@ -499,7 +490,7 @@ def forward( atomic_numbers: torch.Tensor, edge_distance: torch.Tensor, edge_index: torch.Tensor, - SO3_edge_rot: SO3_Rotation, + wigner: torch.Tensor, ) -> torch.Tensor: # Compute messages by performing message block x_message = self.message_block( @@ -507,7 +498,7 @@ def forward( atomic_numbers, edge_distance, edge_index, - SO3_edge_rot, + wigner, ) print(f"x_message: {x_message.mean()}") @@ -580,6 +571,7 @@ def __init__( self.mmax_list = mmax_list self.edge_channels = edge_channels self.mappingReduced = mappingReduced + self.out_mask = self.mappingReduced.coefficient_idx(self.lmax_list[0], self.mmax_list[0]) # Create edge scalar (invariant to rotations) features self.edge_block = EdgeBlock( @@ -615,7 +607,7 @@ def forward( atomic_numbers: torch.Tensor, edge_distance: torch.Tensor, edge_index: torch.Tensor, - SO3_edge_rot: SO3_Rotation, + wigner: torch.Tensor, ) -> torch.Tensor: ############################################################### # Compute messages @@ -635,8 +627,10 @@ def forward( x_target = x_target[edge_index[1, :]] # Rotate the irreps to align with the edge - x_source = SO3_edge_rot.rotate(x_source, self.lmax_list[0], self.mmax_list[0]) - x_target = SO3_edge_rot.rotate(x_target, self.lmax_list[0], self.mmax_list[0]) + x_source = torch.bmm(wigner[:, self.out_mask, :], x_source) + x_target = torch.bmm(wigner[:, self.out_mask, :], x_target) + # x_source = SO3_edge_rot.rotate(x_source, self.lmax_list[0], self.mmax_list[0]) + # x_target = SO3_edge_rot.rotate(x_target, self.lmax_list[0], self.mmax_list[0]) # Compute messages x_source = self.so2_block_source(x_source, x_edge) @@ -653,7 +647,9 @@ def forward( x_target = torch.einsum("bai,zbac->zic", from_grid_mat, x_grid) # Rotate back the irreps - x_target = SO3_edge_rot.rotate_inv(x_target, self.lmax_list[0], self.mmax_list[0]) + # x_target = SO3_edge_rot.rotate_inv(x_target, self.lmax_list[0], self.mmax_list[0]) + wigner_inv = torch.transpose(wigner, 1, 2).contiguous() + x_target = torch.bmm(wigner_inv[:, :, self.out_mask], x_target) # Compute the sum of the incoming neighboring messages for each target node new_embedding = torch.fill(x.clone(), 0) diff --git a/src/fairchem/core/models/utils/so3_utils.py b/src/fairchem/core/models/utils/so3_utils.py index 5845f48438..61e37744e0 100644 --- a/src/fairchem/core/models/utils/so3_utils.py +++ b/src/fairchem/core/models/utils/so3_utils.py @@ -48,6 +48,28 @@ def _z_rot_mat(angle: torch.Tensor, lv: int) -> torch.Tensor: M[..., inds, inds] = torch.cos(frequencies * angle[..., None]) return M +def rotation_to_wigner( + edge_rot_mat: torch.Tensor, start_lmax: int, end_lmax: int +) -> torch.Tensor: + x = edge_rot_mat @ edge_rot_mat.new_tensor([0.0, 1.0, 0.0]) + alpha, beta = o3.xyz_to_angles(x) + R = ( + o3.angles_to_matrix(alpha, beta, torch.zeros_like(alpha)).transpose(-1, -2) + @ edge_rot_mat + ) + gamma = torch.atan2(R[..., 0, 2], R[..., 0, 0]) + + size = int((end_lmax + 1) ** 2) - int((start_lmax) ** 2) + wigner = torch.zeros(len(alpha), size, size, device=edge_rot_mat.device) + start = 0 + for lmax in range(start_lmax, end_lmax + 1): + block 
= wigner_D(lmax, alpha, beta, gamma)
+        end = start + block.size()[1]
+        wigner[:, start:end, start:end] = block
+        start = end
+
+    return wigner.detach()
+
 
 class CoefficientMapping(torch.nn.Module):
     """
@@ -63,13 +85,11 @@ def __init__(
         self,
         lmax_list,
         mmax_list,
-        use_rotate_inv_rescale=False
     ):
         super().__init__()
 
         self.lmax_list = lmax_list
         self.mmax_list = mmax_list
-        self.use_rotate_inv_rescale = use_rotate_inv_rescale
         self.num_resolutions = len(lmax_list)
 
@@ -121,13 +141,9 @@ def __init__(
         self.register_buffer('l_harmonic', l_harmonic)
         self.register_buffer('m_harmonic', m_harmonic)
         self.register_buffer('m_complex', m_complex)
-        # self.register_buffer('res_size', res_size)
         self.register_buffer('to_m', to_m)
-        # self.register_buffer('m_size', m_size)
 
         self.pre_compute_coefficient_idx()
-        if self.use_rotate_inv_rescale:
-            self.pre_compute_rotate_inv_rescale()
 
 
     # Return mask containing coefficients of order m (real and imaginary parts)
@@ -215,30 +231,11 @@ def pre_compute_rotate_inv_rescale(self):
                 rotate_inv_rescale[:, start_idx : (start_idx + length), start_idx : (start_idx + length)] = rescale_factor
                 rotate_inv_rescale = rotate_inv_rescale[:, :, mask_indices]
                 self.register_buffer('rotate_inv_rescale_l{}_m{}'.format(l, m), rotate_inv_rescale)
-
-
-    def prepare_rotate_inv_rescale(self):
-        lmax = max(self.lmax_list)
-        rotate_inv_rescale_list = []
-        for l in range(lmax + 1):
-            l_list = []
-            for m in range(lmax + 1):
-                l_list.append(getattr(self, 'rotate_inv_rescale_l{}_m{}'.format(l, m), None))
-            rotate_inv_rescale_list.append(l_list)
-        return rotate_inv_rescale_list
-
-
-    # Return the re-scaling for rotating back to original frame
-    # this is required since we only use a subset of m components for SO(2) convolution
-    def get_rotate_inv_rescale(self, lmax, mmax):
-        temp = self.prepare_rotate_inv_rescale()
-        return temp[lmax][mmax]
-
     def __repr__(self):
         return f"{self.__class__.__name__}(lmax_list={self.lmax_list}, mmax_list={self.mmax_list})"
 
-class SO3_Rotation(torch.nn.Module):
+class SO3_Rotation:
     """
     Helper functions for Wigner-D rotations
 
@@ -262,12 +259,10 @@ def __init__(
         self.wigner = self.wigner.detach()
         self.wigner_inv = self.wigner_inv.detach()
 
-        self.set_lmax(lmax)
-
-    # Initialize coefficients for reshape l<-->m
-    def set_lmax(self, lmax) -> None:
         self.lmax = lmax
+        self.mapping = CoefficientMapping([self.lmax], [self.lmax])
+
 
     # Rotate the embedding
     def rotate(self, embedding, out_lmax, out_mmax) -> torch.Tensor:
@@ -286,7 +281,7 @@ def rotate_inv(self, embedding, in_lmax, in_mmax) -> torch.Tensor:
     def RotationToWignerDMatrix(
         self, edge_rot_mat: torch.Tensor, start_lmax: int, end_lmax: int
     ) -> torch.Tensor:
-        x = edge_rot_mat @ edge_rot_mat.new_tensor([0.0, 1.0, 0.0])
+        x = edge_rot_mat[:, :, 1]
         alpha, beta = o3.xyz_to_angles(x)
         R = (
             o3.angles_to_matrix(alpha, beta, torch.zeros_like(alpha)).transpose(-1, -2)
             @ edge_rot_mat
diff --git a/tests/core/models/test_escn_compiles.py b/tests/core/models/test_escn_compiles.py
index 557d95cc74..c0458fbb9f 100644
--- a/tests/core/models/test_escn_compiles.py
+++ b/tests/core/models/test_escn_compiles.py
@@ -29,7 +29,7 @@ from fairchem.core.datasets import data_list_collater
 from fairchem.core.preprocessing import AtomsToGraphs
 from fairchem.core.common.transforms import RandomRotate
 from fairchem.core.models.scn.smearing import GaussianSmearing
 from fairchem.core.models.base import GraphModelMixin
-from fairchem.core.models.utils.so3_utils
import CoefficientMapping, SO3_Grid, rotation_to_wigner from fairchem.core.models.escn import escn_exportable from torch.export import export @@ -175,7 +175,7 @@ def test_escn_so2_conv_exports_and_compiles(self, tol=1e-5) -> None: compiled_out = compiled_model(*args) assert torch.allclose(compiled_out, regular_out, atol=tol) - def test_escn_message_block_exports_and_compiles(self) -> None: + def test_escn_message_block_exports_and_compiles(self, tol=1e-5) -> None: random.seed(1) sphere_channels = 128 @@ -186,7 +186,7 @@ def test_escn_message_block_exports_and_compiles(self) -> None: SO3_grid = torch.nn.ModuleDict() SO3_grid["lmax_lmax"] = SO3_Grid(lmax, lmax) SO3_grid["lmax_mmax"] = SO3_Grid(lmax, mmax) - mappingReduced = escn_exportable.CoefficientMapping([lmax], [mmax]) + mappingReduced = CoefficientMapping([lmax], [mmax]) message_block = escn_exportable.MessageBlock( layer_idx = 0, sphere_channels = sphere_channels, @@ -208,7 +208,7 @@ def test_escn_message_block_exports_and_compiles(self) -> None: edge_rot_mat = full_model._init_edge_rot_mat( data, graph.edge_index, graph.edge_distance_vec ) - SO3_edge_rot = SO3_Rotation(edge_rot_mat, lmax) + wigner = rotation_to_wigner(edge_rot_mat, 0, lmax).detach() # generate inputs batch_sizes = [34] @@ -220,15 +220,25 @@ def test_escn_message_block_exports_and_compiles(self) -> None: atom_n = torch.randint(1, 90, (b,)) edge_d = torch.rand([num_edges]) edge_indx = torch.randint(0, b, (2, num_edges)) - args.append((x, atom_n, edge_d, edge_indx, SO3_edge_rot)) + args.append((x, atom_n, edge_d, edge_indx, wigner)) + + torch._dynamo.config.optimize_ddp = False + torch._dynamo.config.assume_static_by_default = False + torch._dynamo.config.automatic_dynamic_shapes = True + torch._dynamo.config.verbose = True + # torch._logging.set_logs(dynamo = logging.INFO) + # torch._dynamo.reset() + # explain_output = torch._dynamo.explain(message_block)(*args[0]) + # print(explain_output) + compiled_model = torch.compile(message_block, dynamic=True) + compiled_output = compiled_model(*args[0]) - # compiled_model = torch.compile(message_block, dynamic=True) exported_prog = export(message_block, args=args[0]) exported_output = exported_prog(*args[0]) - # output = message_block(*args) - # compiled_output = compiled_model(*args) - + regular_out = message_block(*args[0]) + assert torch.allclose(compiled_output, regular_out, atol=tol) + assert torch.allclose(exported_output, regular_out, atol=tol) def test_escn_compiles(self): init("gloo") From 5f223a365cd690127d28017450778955829fe4c9 Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Sun, 1 Sep 2024 21:45:26 -0700 Subject: [PATCH 13/25] layer block compiles and exports --- .../core/models/escn/escn_exportable.py | 3 - tests/core/models/test_escn_compiles.py | 70 +++++++++++++++++-- 2 files changed, 66 insertions(+), 7 deletions(-) diff --git a/src/fairchem/core/models/escn/escn_exportable.py b/src/fairchem/core/models/escn/escn_exportable.py index b13c8e5cd7..97702873a5 100644 --- a/src/fairchem/core/models/escn/escn_exportable.py +++ b/src/fairchem/core/models/escn/escn_exportable.py @@ -500,7 +500,6 @@ def forward( edge_index, wigner, ) - print(f"x_message: {x_message.mean()}") # Compute point-wise spherical non-linearity on aggregated messages @@ -517,14 +516,12 @@ def forward( x_grid = self.act(self.fc1_sphere(x_grid)) x_grid = self.act(self.fc2_sphere(x_grid)) x_grid = self.fc3_sphere(x_grid) - print(f"x_grid: {x_grid.mean()}") # Project back to spherical harmonic coefficients # x_message._from_grid(x_grid, 
self.SO3_grid["lmax_lmax"]) from_grid_mat = self.SO3_grid["lmax_lmax"].from_grid_mat[:, :, self.SO3_grid["lmax_lmax"].mapping.coefficient_idx(self.lmax_list[0], self.lmax_list[0])] x_message_final = torch.einsum("bai,zbac->zic", from_grid_mat, x_grid) - print(f"x_message_final: {x_message_final.mean()}") # Return aggregated messages return x_message_final diff --git a/tests/core/models/test_escn_compiles.py b/tests/core/models/test_escn_compiles.py index c0458fbb9f..8e53eb7028 100644 --- a/tests/core/models/test_escn_compiles.py +++ b/tests/core/models/test_escn_compiles.py @@ -226,10 +226,6 @@ def test_escn_message_block_exports_and_compiles(self, tol=1e-5) -> None: torch._dynamo.config.assume_static_by_default = False torch._dynamo.config.automatic_dynamic_shapes = True torch._dynamo.config.verbose = True - # torch._logging.set_logs(dynamo = logging.INFO) - # torch._dynamo.reset() - # explain_output = torch._dynamo.explain(message_block)(*args[0]) - # print(explain_output) compiled_model = torch.compile(message_block, dynamic=True) compiled_output = compiled_model(*args[0]) @@ -240,6 +236,72 @@ def test_escn_message_block_exports_and_compiles(self, tol=1e-5) -> None: assert torch.allclose(compiled_output, regular_out, atol=tol) assert torch.allclose(exported_output, regular_out, atol=tol) + def test_escn_layer_block_exports_and_compiles(self, tol=1e-5) -> None: + random.seed(1) + + sphere_channels = 128 + hidden_channels = 128 + edge_channels = 128 + lmax, mmax = 4, 2 + distance_expansion = GaussianSmearing(0.0, 8.0, int(8.0 / 0.02), 1.0) + SO3_grid = torch.nn.ModuleDict() + SO3_grid["lmax_lmax"] = SO3_Grid(lmax, lmax) + SO3_grid["lmax_mmax"] = SO3_Grid(lmax, mmax) + mappingReduced = CoefficientMapping([lmax], [mmax]) + layer_block = escn_exportable.LayerBlock( + layer_idx = 0, + sphere_channels = sphere_channels, + hidden_channels = hidden_channels, + edge_channels = edge_channels, + lmax_list = [lmax], + mmax_list = [mmax], + distance_expansion = distance_expansion, + max_num_elements = 90, + SO3_grid = SO3_grid, + act = torch.nn.SiLU(), + mappingReduced = mappingReduced + ) + + # generate inputs + batch_sizes = [34, 35, 35] + num_edges = [680, 700, 680] + num_coefs = 25 + run_args = [] + for b,edges in zip(batch_sizes, num_edges): + x = torch.rand([b, num_coefs, sphere_channels]) + atom_n = torch.randint(1, 90, (b,)) + edge_d = torch.rand([edges]) + edge_indx = torch.randint(0, b, (2, edges)) + wigner = torch.rand([edges, num_coefs, num_coefs]) + run_args.append((x, atom_n, edge_d, edge_indx, wigner)) + + torch._dynamo.config.optimize_ddp = False + torch._dynamo.config.assume_static_by_default = False + torch._dynamo.config.automatic_dynamic_shapes = True + torch._dynamo.config.verbose = True + # torch._logging.set_logs(dynamo = logging.INFO) + # torch._dynamo.reset() + # explain_output = torch._dynamo.explain(message_block)(*args[0]) + # print(explain_output) + + batch_dim = Dim("batch_dim") + edges_dim = Dim("edges_dim") + dynamic_shapes1 = { + "x": {0: batch_dim, 1: None, 2: None}, + "atomic_numbers": {0: batch_dim}, + "edge_distance": {0: edges_dim}, + "edge_index": {0: None, 1: edges_dim}, + "wigner": {0: edges_dim, 1: None, 2: None} + } + exported_prog = export(layer_block, args=run_args[0], dynamic_shapes=dynamic_shapes1) + for run_arg in run_args: + exported_output = exported_prog(*run_arg) + compiled_model = torch.compile(layer_block, dynamic=True) + compiled_output = compiled_model(*run_arg) + regular_out = layer_block(*run_arg) + assert torch.allclose(compiled_output, 
regular_out, atol=tol) + assert torch.allclose(exported_output, regular_out, atol=tol) + def test_escn_compiles(self): init("gloo") data = load_data() From cb679ddd5c36cfd9179cb812f830f5b3dbcd9c20 Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Sun, 1 Sep 2024 23:07:18 -0700 Subject: [PATCH 14/25] remove most of lmax_list and mmax_list --- .../core/models/escn/escn_exportable.py | 204 +++++++----------- tests/core/models/test_escn_compiles.py | 13 +- 2 files changed, 81 insertions(+), 136 deletions(-) diff --git a/src/fairchem/core/models/escn/escn_exportable.py b/src/fairchem/core/models/escn/escn_exportable.py index 97702873a5..dfeb4f0b23 100644 --- a/src/fairchem/core/models/escn/escn_exportable.py +++ b/src/fairchem/core/models/escn/escn_exportable.py @@ -27,7 +27,6 @@ SO3_Rotation, rotation_to_wigner ) -from fairchem.core.models.escn.so3 import SO3_Embedding from fairchem.core.models.scn.sampling import CalcSpherePoints from fairchem.core.models.scn.smearing import ( GaussianSmearing, @@ -56,8 +55,8 @@ class eSCN(nn.Module, GraphModelMixin): max_num_elements (int): Maximum atomic number num_layers (int): Number of layers in the GNN - lmax_list (int): List of maximum degree of the spherical harmonics (1 to 10) - mmax_list (int): List of maximum order of the spherical harmonics (0 to lmax) + lmax (int): maximum degree of the spherical harmonics (1 to 10) + mmax (int): maximum order of the spherical harmonics (0 to lmax) sphere_channels (int): Number of spherical channels (one set per resolution) hidden_channels (int): Number of hidden units in message passing num_sphere_samples (int): Number of samples used to approximate the integration of the sphere in the output blocks @@ -78,8 +77,8 @@ def __init__( cutoff: float = 8.0, max_num_elements: int = 90, num_layers: int = 8, - lmax_list: list[int] | None = None, - mmax_list: list[int] | None = None, + lmax_list: List[int] = [4], # list of 1, for backward compat only right now, + mmax_list: List[int] = [2], # list of 1, for backward compat only right now, sphere_channels: int = 128, hidden_channels: int = 256, edge_channels: int = 128, @@ -90,10 +89,6 @@ def __init__( show_timing_info: bool = False, resolution: int | None = None, ) -> None: - if mmax_list is None: - mmax_list = [2] - if lmax_list is None: - lmax_list = [6] super().__init__() import sys @@ -120,8 +115,9 @@ def __init__( self.grad_forces = False self.lmax_list = lmax_list self.mmax_list = mmax_list - self.num_resolutions: int = len(self.lmax_list) - self.sphere_channels_all: int = self.num_resolutions * self.sphere_channels + assert len(self.lmax_list) == 1 and len(self.mmax_list) == 1 + self.lmax = lmax_list[0] + self.mmax = mmax_list[0] self.basis_width_scalar = basis_width_scalar self.distance_function = distance_function @@ -133,7 +129,7 @@ def __init__( # Weights for message initialization self.sphere_embedding = nn.Embedding( - self.max_num_elements, self.sphere_channels_all + self.max_num_elements, self.sphere_channels ) # Initialize the function used to measure the distances between atoms @@ -174,11 +170,10 @@ def __init__( ) # Initialize the transformations between spherical and grid representations - assert self.num_resolutions == 1, "Only one resolution is supported" self.SO3_grid = nn.ModuleDict() - self.SO3_grid["lmax_lmax"] = SO3_Grid(self.lmax_list[0], self.lmax_list[0], resolution=resolution) - self.SO3_grid["lmax_mmax"] = SO3_Grid(self.lmax_list[0], self.mmax_list[0], resolution=resolution) - self.mappingReduced = CoefficientMapping(self.lmax_list, 
self.mmax_list) + self.SO3_grid["lmax_lmax"] = SO3_Grid(self.lmax, self.lmax, resolution=resolution) + self.SO3_grid["lmax_mmax"] = SO3_Grid(self.lmax, self.mmax, resolution=resolution) + self.mappingReduced = CoefficientMapping([self.lmax], [self.mmax]) # Initialize the blocks for each layer of the GNN self.layer_blocks = nn.ModuleList() @@ -188,8 +183,8 @@ def __init__( self.sphere_channels, self.hidden_channels, self.edge_channels, - self.lmax_list, - self.mmax_list, + self.lmax, + self.mmax, self.distance_expansion, self.max_num_elements, self.SO3_grid, @@ -200,11 +195,11 @@ def __init__( # Output blocks for energy and forces self.energy_block = EnergyBlock( - self.sphere_channels_all, self.num_sphere_samples, self.act + self.sphere_channels, self.num_sphere_samples, self.act ) if self.regress_forces: self.force_block = ForceBlock( - self.sphere_channels_all, self.num_sphere_samples, self.act + self.sphere_channels, self.num_sphere_samples, self.act ) # Create a roughly evenly distributed point sampling of the sphere for the output blocks @@ -213,19 +208,14 @@ def __init__( ) # For each spherical point, compute the spherical harmonic coefficient weights - sphharm_weights: list[nn.Parameter] = [] - for i in range(self.num_resolutions): - sphharm_weights.append( - nn.Parameter( - o3.spherical_harmonics( - torch.arange(0, self.lmax_list[i] + 1).tolist(), - self.sphere_points, - False, - ), - requires_grad=False, - ) - ) - self.sphharm_weights = nn.ParameterList(sphharm_weights) + self.sphharm_weights: nn.Parameter = nn.Parameter( + o3.spherical_harmonics( + torch.arange(0, self.lmax + 1).tolist(), + self.sphere_points, + False, + ), + requires_grad=False, + ) @conditional_grad(torch.enable_grad()) @@ -250,36 +240,25 @@ def forward(self, data): edge_rot_mat = self._init_edge_rot_mat( data, graph.edge_index, graph.edge_distance_vec ) - wigner = rotation_to_wigner(edge_rot_mat, 0, self.lmax_list[0]).detach() + wigner = rotation_to_wigner(edge_rot_mat, 0, self.lmax).detach() ############################################################### # Initialize node embeddings ############################################################### # Init per node representations using an atomic number based embedding - offset = 0 - x = SO3_Embedding( + x_message = torch.zeros( num_atoms, - self.lmax_list, + int((self.lmax + 1) ** 2), self.sphere_channels, - device, - self.dtype, + device=device, + dtype=self.dtype, ) - - offset_res = 0 - offset = 0 - # Initialize the l=0,m=0 coefficients for each resolution - for i in range(self.num_resolutions): - x.embedding[:, offset_res, :] = self.sphere_embedding(atomic_numbers)[ - :, offset : offset + self.sphere_channels - ] - offset = offset + self.sphere_channels - offset_res = offset_res + int((self.lmax_list[i] + 1) ** 2) + x_message[:, 0, :] = self.sphere_embedding(atomic_numbers) ############################################################### # Update spherical node embeddings ############################################################### - x_message = x.embedding for i in range(self.num_layers): if i > 0: x_message_new = self.layer_blocks[i]( @@ -302,29 +281,10 @@ def forward(self, data): graph.edge_index, wigner, ) - x.embedding = x_message # Sample the spherical channels (node embeddings) at evenly distributed points on the sphere. # These values are fed into the output blocks. 
- x_pt = torch.tensor([], device=device) - offset = 0 - # Compute the embedding values at every sampled point on the sphere - for i in range(self.num_resolutions): - num_coefficients = int((x.lmax_list[i] + 1) ** 2) - x_pt = torch.cat( - [ - x_pt, - torch.einsum( - "abc, pb->apc", - x.embedding[:, offset : offset + num_coefficients], - self.sphharm_weights[i], - ).contiguous(), - ], - dim=2, - ) - offset = offset + num_coefficients - - x_pt = x_pt.view(-1, self.sphere_channels_all) + x_pt = torch.einsum("abc, pb->apc", x_message, self.sphharm_weights).contiguous() ############################################################### # Energy estimation @@ -423,8 +383,8 @@ class LayerBlock(torch.nn.Module): sphere_channels (int): Number of spherical channels hidden_channels (int): Number of hidden channels used during the SO(2) conv edge_channels (int): Size of invariant edge embedding - lmax_list (list:int): List of degrees (l) for each resolution - mmax_list (list:int): List of orders (m) for each resolution + lmax (int) degrees (l) for each resolution + mmax (int): orders (m) for each resolution distance_expansion (func): Function used to compute distance embedding max_num_elements (int): Maximum number of atomic numbers SO3_grid (SO3_grid): Class used to convert from grid the spherical harmonic representations @@ -437,8 +397,8 @@ def __init__( sphere_channels: int, hidden_channels: int, edge_channels: int, - lmax_list: list[int], - mmax_list: list[int], + lmax: int, + mmax: int, distance_expansion, max_num_elements: int, SO3_grid: SO3_Grid, @@ -448,11 +408,9 @@ def __init__( super().__init__() self.layer_idx = layer_idx self.act = act - self.lmax_list = lmax_list - self.mmax_list = mmax_list - self.num_resolutions = len(lmax_list) + self.lmax = lmax + self.mmax = mmax self.sphere_channels = sphere_channels - self.sphere_channels_all = self.num_resolutions * self.sphere_channels self.SO3_grid = SO3_grid self.mappingReduced = mappingReduced @@ -462,8 +420,8 @@ def __init__( self.sphere_channels, hidden_channels, edge_channels, - self.lmax_list, - self.mmax_list, + self.lmax, + self.mmax, distance_expansion, max_num_elements, self.SO3_grid, @@ -473,15 +431,15 @@ def __init__( # Non-linear point-wise comvolution for the aggregated messages self.fc1_sphere = nn.Linear( - 2 * self.sphere_channels_all, self.sphere_channels_all, bias=False + 2 * self.sphere_channels, self.sphere_channels, bias=False ) self.fc2_sphere = nn.Linear( - self.sphere_channels_all, self.sphere_channels_all, bias=False + self.sphere_channels, self.sphere_channels, bias=False ) self.fc3_sphere = nn.Linear( - self.sphere_channels_all, self.sphere_channels_all, bias=False + self.sphere_channels, self.sphere_channels, bias=False ) def forward( @@ -505,7 +463,7 @@ def forward( # Project to grid # x_grid_message = x_message.to_grid(self.SO3_grid["lmax_lmax"]) - to_grid_mat = self.SO3_grid["lmax_lmax"].to_grid_mat[:, :, self.SO3_grid["lmax_lmax"].mapping.coefficient_idx(self.lmax_list[0], self.lmax_list[0])] + to_grid_mat = self.SO3_grid["lmax_lmax"].to_grid_mat[:, :, self.SO3_grid["lmax_lmax"].mapping.coefficient_idx(self.lmax, self.lmax)] x_grid_message = torch.einsum("bai,zic->zbac", to_grid_mat, x_message) # x_grid = x.to_grid(self.SO3_grid["lmax_lmax"]) @@ -519,7 +477,7 @@ def forward( # Project back to spherical harmonic coefficients # x_message._from_grid(x_grid, self.SO3_grid["lmax_lmax"]) - from_grid_mat = self.SO3_grid["lmax_lmax"].from_grid_mat[:, :, self.SO3_grid["lmax_lmax"].mapping.coefficient_idx(self.lmax_list[0], 
self.lmax_list[0])] + from_grid_mat = self.SO3_grid["lmax_lmax"].from_grid_mat[:, :, self.SO3_grid["lmax_lmax"].mapping.coefficient_idx(self.lmax, self.lmax)] x_message_final = torch.einsum("bai,zbac->zic", from_grid_mat, x_grid) # Return aggregated messages @@ -535,8 +493,8 @@ class MessageBlock(torch.nn.Module): sphere_channels (int): Number of spherical channels hidden_channels (int): Number of hidden channels used during the SO(2) conv edge_channels (int): Size of invariant edge embedding - lmax_list (list:int): List of degrees (l) for each resolution - mmax_list (list:int): List of orders (m) for each resolution + lmax (int): degrees (l) for each resolution + mmax (int): orders (m) for each resolution distance_expansion (func): Function used to compute distance embedding max_num_elements (int): Maximum number of atomic numbers SO3_grid (SO3_grid): Class used to convert from grid the spherical harmonic representations @@ -549,8 +507,8 @@ def __init__( sphere_channels: int, hidden_channels: int, edge_channels: int, - lmax_list: list[int], - mmax_list: list[int], + lmax: int, + mmax: int, distance_expansion, max_num_elements: int, SO3_grid: SO3_Grid, @@ -563,12 +521,11 @@ def __init__( self.hidden_channels = hidden_channels self.sphere_channels = sphere_channels self.SO3_grid = SO3_grid - self.num_resolutions = len(lmax_list) - self.lmax_list = lmax_list - self.mmax_list = mmax_list + self.lmax = lmax + self.mmax = mmax self.edge_channels = edge_channels self.mappingReduced = mappingReduced - self.out_mask = self.mappingReduced.coefficient_idx(self.lmax_list[0], self.mmax_list[0]) + self.out_mask = self.mappingReduced.coefficient_idx(self.lmax, self.mmax) # Create edge scalar (invariant to rotations) features self.edge_block = EdgeBlock( @@ -583,8 +540,8 @@ def __init__( self.sphere_channels, self.hidden_channels, self.edge_channels, - self.lmax_list, - self.mmax_list, + self.lmax, + self.mmax, self.act, self.mappingReduced ) @@ -592,8 +549,8 @@ def __init__( self.sphere_channels, self.hidden_channels, self.edge_channels, - self.lmax_list, - self.mmax_list, + self.lmax, + self.mmax, self.act, self.mappingReduced ) @@ -626,8 +583,6 @@ def forward( # Rotate the irreps to align with the edge x_source = torch.bmm(wigner[:, self.out_mask, :], x_source) x_target = torch.bmm(wigner[:, self.out_mask, :], x_target) - # x_source = SO3_edge_rot.rotate(x_source, self.lmax_list[0], self.mmax_list[0]) - # x_target = SO3_edge_rot.rotate(x_target, self.lmax_list[0], self.mmax_list[0]) # Compute messages x_source = self.so2_block_source(x_source, x_edge) @@ -644,14 +599,12 @@ def forward( x_target = torch.einsum("bai,zbac->zic", from_grid_mat, x_grid) # Rotate back the irreps - # x_target = SO3_edge_rot.rotate_inv(x_target, self.lmax_list[0], self.mmax_list[0]) wigner_inv = torch.transpose(wigner, 1, 2).contiguous() x_target = torch.bmm(wigner_inv[:, :, self.out_mask], x_target) # Compute the sum of the incoming neighboring messages for each target node new_embedding = torch.fill(x.clone(), 0) new_embedding.index_add_(0, edge_index[1], x_target) - # x_target._reduce_edge(edge_index[1], len(x.embedding)) return new_embedding @@ -664,8 +617,8 @@ class SO2Block(torch.nn.Module): sphere_channels (int): Number of spherical channels hidden_channels (int): Number of hidden channels used during the SO(2) conv edge_channels (int): Size of invariant edge embedding - lmax_list (list:int): List of degrees (l) for each resolution - mmax_list (list:int): List of orders (m) for each resolution + lmax (int): degrees (l) 
for each resolution + mmax (int): orders (m) for each resolution act (function): Non-linear activation function """ @@ -674,24 +627,20 @@ def __init__( sphere_channels: int, hidden_channels: int, edge_channels: int, - lmax_list: list[int], - mmax_list: list[int], + lmax: int, + mmax: int, act, mappingReduced ) -> None: super().__init__() self.sphere_channels = sphere_channels self.hidden_channels = hidden_channels - self.lmax_list = lmax_list - self.mmax_list = mmax_list - self.num_resolutions: int = len(lmax_list) + self.lmax = lmax + self.mmax = mmax self.act = act self.mappingReduced = mappingReduced - num_channels_m0 = 0 - for i in range(self.num_resolutions): - num_coefficents = self.lmax_list[i] + 1 - num_channels_m0 = num_channels_m0 + num_coefficents * self.sphere_channels + num_channels_m0 = (self.lmax + 1) * self.sphere_channels # SO(2) convolution for m=0 self.fc1_dist0 = nn.Linear(edge_channels, self.hidden_channels) @@ -700,14 +649,14 @@ def __init__( # SO(2) convolution for non-zero m self.so2_conv = nn.ModuleList() - for m in range(1, max(self.mmax_list) + 1): + for m in range(1, self.mmax + 1): so2_conv = SO2Conv( m, self.sphere_channels, self.hidden_channels, edge_channels, - self.lmax_list, - self.mmax_list, + self.lmax, + self.mmax, self.act, ) self.so2_conv.append(so2_conv) @@ -740,7 +689,7 @@ def forward( # Compute the values for the m > 0 coefficients offset = self.mappingReduced.m_size[0] - for m in range(1, max(self.mmax_list) + 1): + for m in range(1, self.mmax + 1): # Get the m order coefficients x_m = x[ :, offset : offset + 2 * self.mappingReduced.m_size[m] @@ -769,8 +718,8 @@ class SO2Conv(torch.nn.Module): sphere_channels (int): Number of spherical channels hidden_channels (int): Number of hidden channels used during the SO(2) conv edge_channels (int): Size of invariant edge embedding - lmax_list (list:int): List of degrees (l) for each resolution - mmax_list (list:int): List of orders (m) for each resolution + lmax (int): degrees (l) for each resolution + mmax (int): orders (m) for each resolution act (function): Non-linear activation function """ @@ -780,26 +729,23 @@ def __init__( sphere_channels: int, hidden_channels: int, edge_channels: int, - lmax_list: list[int], - mmax_list: list[int], + lmax: int, + mmax: int, act, ) -> None: super().__init__() self.hidden_channels = hidden_channels - self.lmax_list = lmax_list - self.mmax_list = mmax_list + self.lmax = lmax + self.mmax = mmax self.sphere_channels = sphere_channels - self.num_resolutions: int = len(self.lmax_list) self.m = m self.act = act - num_channels = 0 - for i in range(self.num_resolutions): - num_coefficents = 0 - if self.mmax_list[i] >= m: - num_coefficents = self.lmax_list[i] - m + 1 + num_coefficents = 0 + if self.mmax >= m: + num_coefficents = self.lmax - m + 1 - num_channels = num_channels + num_coefficents * self.sphere_channels + num_channels = num_coefficents * self.sphere_channels assert num_channels > 0 diff --git a/tests/core/models/test_escn_compiles.py b/tests/core/models/test_escn_compiles.py index 8e53eb7028..e653ef4051 100644 --- a/tests/core/models/test_escn_compiles.py +++ b/tests/core/models/test_escn_compiles.py @@ -161,8 +161,8 @@ def test_escn_so2_conv_exports_and_compiles(self, tol=1e-5) -> None: sphere_channels=shpere_channels, hidden_channels=128, edge_channels=edge_channels, - lmax_list=[lmax], - mmax_list=[mmax], + lmax=lmax, + mmax=mmax, act=torch.nn.SiLU(), mappingReduced=mappingReduced ) @@ -192,8 +192,8 @@ def test_escn_message_block_exports_and_compiles(self, 
tol=1e-5) -> None: sphere_channels = sphere_channels, hidden_channels = hidden_channels, edge_channels = edge_channels, - lmax_list = [lmax], - mmax_list = [mmax], + lmax = lmax, + mmax = mmax, distance_expansion = distance_expansion, max_num_elements = 90, SO3_grid = SO3_grid, @@ -253,8 +253,8 @@ def test_escn_layer_block_exports_and_compiles(self, tol=1e-5) -> None: sphere_channels = sphere_channels, hidden_channels = hidden_channels, edge_channels = edge_channels, - lmax_list = [lmax], - mmax_list = [mmax], + lmax = lmax, + mmax = mmax, distance_expansion = distance_expansion, max_num_elements = 90, SO3_grid = SO3_grid, @@ -320,7 +320,6 @@ def test_escn_compiles(self): # os.environ["TORCHDYNAMO_REPRO_AFTER"]="dynamo" # torch._dynamo.config.verbose = True compiled_model = torch.compile(model, dynamic=True) - torch._dynamo.config.optimize_ddp = False # torch._dynamo.explain(model)(data) # assert False # torch._dynamo.reset() From 2250fa6d01071d408b9f9e7989ee9a44f0e626d0 Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Mon, 2 Sep 2024 14:49:46 -0700 Subject: [PATCH 15/25] remove eqv2 stuff from this branch --- .../core/models/equiformer_v2/so2_ops.py | 140 ----------- src/fairchem/core/models/equiformer_v2/so3.py | 222 ++++++++---------- tests/core/models/test_eqv2_compiles.py | 194 --------------- 3 files changed, 93 insertions(+), 463 deletions(-) delete mode 100644 tests/core/models/test_eqv2_compiles.py diff --git a/src/fairchem/core/models/equiformer_v2/so2_ops.py b/src/fairchem/core/models/equiformer_v2/so2_ops.py index 1848fae2cb..71666284e8 100644 --- a/src/fairchem/core/models/equiformer_v2/so2_ops.py +++ b/src/fairchem/core/models/equiformer_v2/so2_ops.py @@ -139,7 +139,6 @@ def __init__( self.rad_func = RadialFunction(self.edge_channels_list) def forward(self, x, x_edge): - num_edges = len(x_edge) out = [] @@ -353,142 +352,3 @@ def forward(self, x, x_edge): out_embedding._l_primary(self.mappingReduced) return out_embedding - -class SO2_Convolution_Exportable(torch.nn.Module): - """ - SO(2) Block: Perform SO(2) convolutions for all m (orders) - - Args: - sphere_channels (int): Number of spherical channels - m_output_channels (int): Number of output channels used during the SO(2) conv - lmax_list (list:int): List of degrees (l) for each resolution - mmax_list (list:int): List of orders (m) for each resolution - mappingReduced (CoefficientMappingModule): Used to extract a subset of m components - internal_weights (bool): If True, not using radial function to multiply inputs features - edge_channels_list (list:int): List of sizes of invariant edge embedding. For example, [input_channels, hidden_channels, hidden_channels]. - extra_m0_output_channels (int): If not None, return `out_embedding` (SO3_Embedding) and `extra_m0_features` (Tensor). 
- """ - - def __init__( - self, - sphere_channels: int, - m_output_channels: int, - lmax_list: list[int], - mmax_list: list[int], - mappingReduced, - internal_weights: bool = True, - edge_channels_list: list[int] | None = None, - extra_m0_output_channels: int | None = None, - ): - super().__init__() - self.sphere_channels = sphere_channels - self.m_output_channels = m_output_channels - self.lmax_list = lmax_list - self.mmax_list = mmax_list - self.mappingReduced = mappingReduced - self.num_resolutions = len(lmax_list) - self.internal_weights = internal_weights - self.edge_channels_list = copy.deepcopy(edge_channels_list) - self.extra_m0_output_channels = extra_m0_output_channels - - num_channels_rad = 0 # for radial function - - num_channels_m0 = 0 - for i in range(self.num_resolutions): - num_coefficients = self.lmax_list[i] + 1 - num_channels_m0 = num_channels_m0 + num_coefficients * self.sphere_channels - - # SO(2) convolution for m = 0 - m0_output_channels = self.m_output_channels * ( - num_channels_m0 // self.sphere_channels - ) - if self.extra_m0_output_channels is not None: - m0_output_channels = m0_output_channels + self.extra_m0_output_channels - self.fc_m0 = Linear(num_channels_m0, m0_output_channels) - num_channels_rad = num_channels_rad + self.fc_m0.in_features - - # SO(2) convolution for non-zero m - self.so2_m_conv = nn.ModuleList() - for m in range(1, max(self.mmax_list) + 1): - self.so2_m_conv.append( - SO2_m_Convolution( - m, - self.sphere_channels, - self.m_output_channels, - self.lmax_list, - self.mmax_list, - ) - ) - num_channels_rad = num_channels_rad + self.so2_m_conv[-1].fc.in_features - - # Embedding function of distance - self.rad_func = None - if not self.internal_weights: - assert self.edge_channels_list is not None - self.edge_channels_list.append(int(num_channels_rad)) - self.rad_func = RadialFunction(self.edge_channels_list) - - def forward(self, x_emb, x_edge): - # x_emb: [num_edges, num_sh_coefs, num_features] - # x_edge: [num_edges, num_edge_features] - - num_edges = x_edge.shape[0] - out = [] - # torch export does not inputs based on a buffered tensor - m_size = self.mappingReduced.m_size - - # Reshape the spherical harmonics based on m (order), equivalent to x._m_primary - x_emb = torch.einsum("nac, ba -> nbc", x_emb, self.mappingReduced.to_m) - - # radial function - if self.rad_func is not None: - x_edge = self.rad_func(x_edge) - offset_rad = 0 - - # Compute m=0 coefficients separately since they only have real values (no imaginary) - x_0 = x_emb.narrow(1, 0, m_size[0]) - x_0 = x_0.reshape(x_edge.shape[0], -1) - if self.rad_func is not None: - x_edge_0 = x_edge.narrow(1, 0, self.fc_m0.in_features) - x_0 = x_0 * x_edge_0 - x_0 = self.fc_m0(x_0) - - x_0_extra = None - # extract extra m0 features - if self.extra_m0_output_channels is not None: - x_0_extra = x_0.narrow(-1, 0, self.extra_m0_output_channels) - x_0 = x_0.narrow( - -1, - self.extra_m0_output_channels, - (self.fc_m0.out_features - self.extra_m0_output_channels), - ) - - x_0 = x_0.view(num_edges, -1, self.m_output_channels) - out.append(x_0) - offset_rad = offset_rad + self.fc_m0.in_features - - # Compute the values for the m > 0 coefficients - offset = m_size[0] - for m in range(1, max(self.mmax_list) + 1): - # Get the m order coefficients - x_m = x_emb.narrow(1, offset, 2 * m_size[m]) - x_m = x_m.reshape(num_edges, 2, -1) - - # Perform SO(2) convolution - if self.rad_func is not None: - x_edge_m = x_edge.narrow( - 1, offset_rad, self.so2_m_conv[m - 1].fc.in_features - ) - x_edge_m = 
x_edge_m.reshape( - num_edges, 1, self.so2_m_conv[m - 1].fc.in_features - ) - x_m = x_m * x_edge_m - x_m = self.so2_m_conv[m - 1](x_m) - x_m = x_m.view(num_edges, -1, self.m_output_channels) - out.append(x_m) - offset = offset + 2 * m_size[m] - offset_rad = offset_rad + self.so2_m_conv[m - 1].fc.in_features - - out = torch.cat(out, dim=1) - out = torch.einsum("nac, ab -> nbc", out, self.mappingReduced.to_m) - return out diff --git a/src/fairchem/core/models/equiformer_v2/so3.py b/src/fairchem/core/models/equiformer_v2/so3.py index 61c25f9437..a3d58586e0 100644 --- a/src/fairchem/core/models/equiformer_v2/so3.py +++ b/src/fairchem/core/models/equiformer_v2/so3.py @@ -30,56 +30,52 @@ class CoefficientMappingModule(torch.nn.Module): """ - Helper module for coefficients used to reshape l <--> m and to get coefficients of specific degree or order + Helper module for coefficients used to reshape lval <--> m and to get coefficients of specific degree or order Args: lmax_list (list:int): List of maximum degree of the spherical harmonics mmax_list (list:int): List of maximum order of the spherical harmonics - use_rotate_inv_rescale (bool): Whether to pre-compute inverse rotation rescale matrices """ def __init__( self, - lmax_list, - mmax_list, - use_rotate_inv_rescale=False + lmax_list: list[int], + mmax_list: list[int], ): super().__init__() self.lmax_list = lmax_list self.mmax_list = mmax_list - self.use_rotate_inv_rescale = use_rotate_inv_rescale self.num_resolutions = len(lmax_list) - assert (len(self.lmax_list) == 1) and (len(self.mmax_list) == 1) - - # Compute the degree (l) and order (m) for each entry of the embedding - l_harmonic = torch.tensor([]).long() - m_harmonic = torch.tensor([]).long() - m_complex = torch.tensor([]).long() + # Temporarily use `cpu` as device and this will be overwritten. + self.device = "cpu" + + # Compute the degree (lval) and order (m) for each entry of the embedding + l_harmonic = torch.tensor([], device=self.device).long() + m_harmonic = torch.tensor([], device=self.device).long() + m_complex = torch.tensor([], device=self.device).long() - res_size = torch.zeros([self.num_resolutions]).long() + res_size = torch.zeros([self.num_resolutions], device=self.device).long() offset = 0 for i in range(self.num_resolutions): - for l in range(0, self.lmax_list[i] + 1): - mmax = min(self.mmax_list[i], l) - m = torch.arange(-mmax, mmax + 1).long() + for lval in range(self.lmax_list[i] + 1): + mmax = min(self.mmax_list[i], lval) + m = torch.arange(-mmax, mmax + 1, device=self.device).long() m_complex = torch.cat([m_complex, m], dim=0) - m_harmonic = torch.cat( - [m_harmonic, torch.abs(m).long()], dim=0 - ) - l_harmonic = torch.cat( - [l_harmonic, m.fill_(l).long()], dim=0 - ) + m_harmonic = torch.cat([m_harmonic, torch.abs(m).long()], dim=0) + l_harmonic = torch.cat([l_harmonic, m.fill_(lval).long()], dim=0) res_size[i] = len(l_harmonic) - offset offset = len(l_harmonic) num_coefficients = len(l_harmonic) # `self.to_m` moves m components from different L to contiguous index - to_m = torch.zeros([num_coefficients, num_coefficients]) - self.m_size = torch.zeros([max(self.mmax_list) + 1]).long().tolist() + to_m = torch.zeros([num_coefficients, num_coefficients], device=self.device) + m_size = torch.zeros([max(self.mmax_list) + 1], device=self.device).long() + # The following is implemented poorly - very slow. It only gets called + # a few times so haven't optimized. 
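# In effect, the loop below fills `to_m` as a 0/1 permutation-like matrix:
# applied to an l-major coefficient vector, it groups every m = 0 entry
# first, then the real and imaginary components for each successive |m|.
# A toy check under the assumption lmax = mmax = 1, i.e. 4 coefficients
# (illustrative only; not code from this patch):
#
#     import torch
#     # l-major order: (l=0,m=0), (l=1,m=-1), (l=1,m=0), (l=1,m=1)
#     # m-major order: m=0 rows (indices 0, 2), then m=1 real (3), imag (1)
#     to_m = torch.zeros(4, 4)
#     for row, col in enumerate([0, 2, 3, 1]):
#         to_m[row, col] = 1.0
#     print(to_m @ torch.arange(4.0))  # tensor([0., 2., 3., 1.])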
offset = 0 for m in range(max(self.mmax_list) + 1): idx_r, idx_i = self.complex_idx(m, -1, m_complex, l_harmonic) @@ -88,7 +84,7 @@ def __init__( to_m[idx_out + offset, idx_in] = 1.0 offset = offset + len(idx_r) - self.m_size[m] = int(len(idx_r)) + m_size[m] = int(len(idx_r)) for idx_out, idx_in in enumerate(idx_i): to_m[idx_out + offset, idx_in] = 1.0 @@ -97,124 +93,93 @@ def __init__( to_m = to_m.detach() # save tensors and they will be moved to GPU - self.register_buffer('l_harmonic', l_harmonic) - self.register_buffer('m_harmonic', m_harmonic) - self.register_buffer('m_complex', m_complex) - self.register_buffer('res_size', res_size) - self.register_buffer('to_m', to_m) - # self.register_buffer('m_size', m_size) - - self.pre_compute_coefficient_idx() - if self.use_rotate_inv_rescale: - self.pre_compute_rotate_inv_rescale() - + self.register_buffer("l_harmonic", l_harmonic) + self.register_buffer("m_harmonic", m_harmonic) + self.register_buffer("m_complex", m_complex) + self.register_buffer("res_size", res_size) + self.register_buffer("to_m", to_m) + self.register_buffer("m_size", m_size) + + # for caching the output of `coefficient_idx` + self.lmax_cache, self.mmax_cache = None, None + self.mask_indices_cache = None + self.rotate_inv_rescale_cache = None # Return mask containing coefficients of order m (real and imaginary parts) - def complex_idx(self, m, lmax, m_complex, l_harmonic): - ''' - Add `m_complex` and `l_harmonic` to the input arguments - since we cannot use `self.m_complex`. - ''' + def complex_idx(self, m: int, lmax: int, m_complex, l_harmonic): + """ + Add `m_complex` and `l_harmonic` to the input arguments + since we cannot use `self.m_complex`. + """ if lmax == -1: lmax = max(self.lmax_list) - indices = torch.arange(len(l_harmonic)) + indices = torch.arange(len(l_harmonic), device=self.device) # Real part - mask_r = torch.bitwise_and( - l_harmonic.le(lmax), m_complex.eq(m) - ) + mask_r = torch.bitwise_and(l_harmonic.le(lmax), m_complex.eq(m)) mask_idx_r = torch.masked_select(indices, mask_r) - mask_idx_i = torch.tensor([]).long() + mask_idx_i = torch.tensor([], device=self.device).long() # Imaginary part if m != 0: - mask_i = torch.bitwise_and( - l_harmonic.le(lmax), m_complex.eq(-m) - ) + mask_i = torch.bitwise_and(l_harmonic.le(lmax), m_complex.eq(-m)) mask_idx_i = torch.masked_select(indices, mask_i) return mask_idx_r, mask_idx_i - - def pre_compute_coefficient_idx(self): - ''' - Pre-compute the results of `coefficient_idx()` and access them with `prepare_coefficient_idx()` - ''' - lmax = max(self.lmax_list) - for l in range(lmax + 1): - for m in range(lmax + 1): - mask = torch.bitwise_and( - self.l_harmonic.le(l), self.m_harmonic.le(m) - ) - indices = torch.arange(len(mask)) - mask_indices = torch.masked_select(indices, mask) - self.register_buffer('coefficient_idx_l{}_m{}'.format(l, m), mask_indices) - - - def prepare_coefficient_idx(self): - ''' - Construct a list of buffers - ''' - lmax = max(self.lmax_list) - coefficient_idx_list = [] - for l in range(lmax + 1): - l_list = [] - for m in range(lmax + 1): - l_list.append(getattr(self, 'coefficient_idx_l{}_m{}'.format(l, m), None)) - coefficient_idx_list.append(l_list) - return coefficient_idx_list - - - # Return mask containing coefficients less than or equal to degree (l) and order (m) - def coefficient_idx(self, lmax, mmax): - if lmax > max(self.lmax_list) or mmax > max(self.lmax_list): - mask = torch.bitwise_and( - self.l_harmonic.le(lmax), self.m_harmonic.le(mmax) - ) - indices = torch.arange(len(mask), 
device=mask.device) - mask_indices = torch.masked_select(indices, mask) - return mask_indices - else: - temp = self.prepare_coefficient_idx() - return temp[lmax][mmax] - - - def pre_compute_rotate_inv_rescale(self): - lmax = max(self.lmax_list) - for l in range(lmax + 1): - for m in range(lmax + 1): - mask_indices = self.coefficient_idx(l, m) - rotate_inv_rescale = torch.ones((1, int((l + 1)**2), int((l + 1)**2))) - for l_sub in range(l + 1): - if l_sub <= m: - continue - start_idx = l_sub ** 2 - length = 2 * l_sub + 1 - rescale_factor = math.sqrt(length / (2 * m + 1)) - rotate_inv_rescale[:, start_idx : (start_idx + length), start_idx : (start_idx + length)] = rescale_factor - rotate_inv_rescale = rotate_inv_rescale[:, :, mask_indices] - self.register_buffer('rotate_inv_rescale_l{}_m{}'.format(l, m), rotate_inv_rescale) - - - def prepare_rotate_inv_rescale(self): - lmax = max(self.lmax_list) - rotate_inv_rescale_list = [] - for l in range(lmax + 1): - l_list = [] - for m in range(lmax + 1): - l_list.append(getattr(self, 'rotate_inv_rescale_l{}_m{}'.format(l, m), None)) - rotate_inv_rescale_list.append(l_list) - return rotate_inv_rescale_list - + # Return mask containing coefficients less than or equal to degree (lval) and order (m) + def coefficient_idx(self, lmax: int, mmax: int): + if ( + (self.lmax_cache is not None) + and (self.mmax_cache is not None) + and (self.lmax_cache == lmax) + and (self.mmax_cache == mmax) + and self.mask_indices_cache is not None + ): + return self.mask_indices_cache + + mask = torch.bitwise_and(self.l_harmonic.le(lmax), self.m_harmonic.le(mmax)) + self.device = mask.device + indices = torch.arange(len(mask), device=self.device) + mask_indices = torch.masked_select(indices, mask) + self.lmax_cache, self.mmax_cache = lmax, mmax + self.mask_indices_cache = mask_indices + return self.mask_indices_cache # Return the re-scaling for rotating back to original frame # this is required since we only use a subset of m components for SO(2) convolution - def get_rotate_inv_rescale(self, lmax, mmax): - temp = self.prepare_rotate_inv_rescale() - return temp[lmax][mmax] - - - def __repr__(self): + def get_rotate_inv_rescale(self, lmax: int, mmax: int): + if ( + (self.lmax_cache is not None) + and (self.mmax_cache is not None) + and (self.lmax_cache == lmax) + and (self.mmax_cache == mmax) + and self.rotate_inv_rescale_cache is not None + ): + return self.rotate_inv_rescale_cache + + if self.mask_indices_cache is None: + self.coefficient_idx(lmax, mmax) + + rotate_inv_rescale = torch.ones( + (1, (lmax + 1) ** 2, (lmax + 1) ** 2), device=self.device + ) + for lval in range(lmax + 1): + if lval <= mmax: + continue + start_idx = lval**2 + length = 2 * lval + 1 + rescale_factor = math.sqrt(length / (2 * mmax + 1)) + rotate_inv_rescale[ + :, + start_idx : (start_idx + length), + start_idx : (start_idx + length), + ] = rescale_factor + rotate_inv_rescale = rotate_inv_rescale[:, :, self.mask_indices_cache] + self.rotate_inv_rescale_cache = rotate_inv_rescale + return self.rotate_inv_rescale_cache + + def __repr__(self) -> str: return f"{self.__class__.__name__}(lmax_list={self.lmax_list}, mmax_list={self.mmax_list})" @@ -482,7 +447,7 @@ def __init__( ): super().__init__() self.lmax = lmax - self.mapping = CoefficientMappingModule([self.lmax], [self.lmax], use_rotate_inv_rescale=True) + self.mapping = CoefficientMappingModule([self.lmax], [self.lmax]) def set_wigner(self, rot_mat3x3): self.device, self.dtype = rot_mat3x3.device, rot_mat3x3.dtype @@ -517,7 +482,7 @@ def 
RotationToWignerDMatrix( ) gamma = torch.atan2(R[..., 0, 2], R[..., 0, 0]) - size = int((end_lmax + 1) ** 2) - int((start_lmax) ** 2) + size = (end_lmax + 1) ** 2 - (start_lmax) ** 2 wigner = torch.zeros(len(alpha), size, size, device=self.device) start = 0 for lmax in range(start_lmax, end_lmax + 1): @@ -546,7 +511,6 @@ def __init__( resolution: int | None = None, ): super().__init__() - self.lmax = lmax self.mmax = mmax self.lat_resolution = 2 * (self.lmax + 1) diff --git a/tests/core/models/test_eqv2_compiles.py b/tests/core/models/test_eqv2_compiles.py deleted file mode 100644 index d4e3390980..0000000000 --- a/tests/core/models/test_eqv2_compiles.py +++ /dev/null @@ -1,194 +0,0 @@ -""" -Copyright (c) Facebook, Inc. and its affiliates. - -This source code is licensed under the MIT license found in the -LICENSE file in the root directory of this source tree. -""" - -from __future__ import annotations - -import copy -import io -import os - -import pytest -import requests -import torch -from ase.io import read -from torch.nn.parallel.distributed import DistributedDataParallel - -from fairchem.core.common.registry import registry -from fairchem.core.common.distutils import init_local_distributed_process_group -from fairchem.core.common.utils import load_state_dict, setup_imports -from fairchem.core.datasets import data_list_collater -from fairchem.core.preprocessing import AtomsToGraphs -from torch_geometric.data import Data, Batch - -from fairchem.core.models.equiformer_v2.so3 import CoefficientMappingModule, SO3_Embedding -from fairchem.core.models.equiformer_v2.so2_ops import SO2_Convolution, SO2_Convolution_Exportable - -from torch.export import export -from torch.export import Dim - -def load_data(): - atoms = read( - os.path.join(os.path.dirname(os.path.abspath(__file__)), "atoms.json"), - index=":", - format="json", - ) - a2g = AtomsToGraphs( - max_neigh=200, - radius=6, - r_edges=False, - r_fixed=True, - ) - data_list = a2g.convert_all(atoms) - return data_list - - -def load_model(): - torch.manual_seed(4) - setup_imports() - model = registry.get_model_class("equiformer_v2")( - use_pbc=True, - regress_forces=True, - otf_graph=True, - max_neighbors=20, - max_radius=12.0, - max_num_elements=90, - num_layers=8, - sphere_channels=128, - attn_hidden_channels=64, - num_heads=8, - attn_alpha_channels=64, - attn_value_channels=16, - ffn_hidden_channels=128, - norm_type="layer_norm_sh", - lmax_list=[4], - mmax_list=[2], - grid_resolution=18, - num_sphere_samples=128, - edge_channels=128, - use_atom_edge_embedding=True, - distance_function="gaussian", - num_distance_basis=512, - attn_activation="silu", - use_s2_act_attn=False, - ffn_activation="silu", - use_gate_act=False, - use_grid_mlp=True, - alpha_drop=0.1, - drop_path_rate=0.1, - proj_drop=0.0, - weight_init="uniform", - ) - return model - - -def init(backend="nccl"): - if not torch.distributed.is_initialized(): - init_local_distributed_process_group(backend=backend) - - -def expected_energy_forces(): - energy = torch.tensor([-0.0261]) - forces = torch.tensor([-0.0008, -0.0018, -0.0020]) - return energy, forces - - -def rand_input(natoms: int) -> BaseData: - data = Data(natoms=natoms, - pos=torch.rand(natoms, 3), - cell=torch.rand([1, 3, 3]), - atomic_numbers=torch.randint(1, 99, (1, 3, 3)) - ) - batch = Batch.from_data_list([data]) - return batch - -class TestEQV2Compiles: - def eqv2_baseline_output(self, backend: str): - init(backend=backend) - data = load_data() - data = data_list_collater([data[0]])#.to("cuda") - model = 
load_model()#.cuda() - ddp_model = DistributedDataParallel(model) - return ddp_model(data) - - def test_baseline_cpu(self): - outputs = self.eqv2_baseline_output("gloo") - energy, forces_mean = outputs["energy"].detach().cpu(), outputs["forces"].mean(0).detach().cpu() - expected_energy, expected_forces = expected_energy_forces() - assert torch.allclose(energy, expected_energy, atol=1e-4) - assert torch.allclose(forces_mean, expected_forces, atol=1e-4) - - def test_eqv2_compiles(self): - init() - data = load_data() - data0 = data_list_collater([data[0]]).to("cuda") - data1 = data_list_collater([data[1]]).to("cuda") - model = load_model().cuda() - ddp_model = DistributedDataParallel(model) - - torch._dynamo.config.optimize_ddp = False - torch._dynamo.config.assume_static_by_default = False - torch._dynamo.config.automatic_dynamic_shapes = True - torch._dynamo.config.cache_size_limit = 1 - # torch._dynamo.config.suppress_errors = True - - os.environ["TORCH_LOGS"] = "+dynamo,recompiles" - compiled_model = torch.compile(model, dynamic=True) - torch._dynamo.config.optimize_ddp = False - compiled_model(data0) - compiled_model(data1) - # import pdb; pdb.set_trace() - -class TestExportableEQV2: - def test_so2_conv_equivalent(self): - torch.manual_seed(4) - lmax, mmax = 4, 2 - sc, mc = 128, 128 - mappingReduced = CoefficientMappingModule([lmax], [mmax]) - - start_rng_state = torch.random.get_rng_state() - so2_export = SO2_Convolution_Exportable(sphere_channels=sc, m_output_channels=mc, lmax_list=[lmax], mmax_list=[mmax],mappingReduced=mappingReduced) - torch.random.set_rng_state(start_rng_state) - so2 = SO2_Convolution(sphere_channels=sc, m_output_channels=mc, lmax_list=[lmax], mmax_list=[mmax],mappingReduced=mappingReduced) - - inputs_tensor = (torch.rand(129, 19, 128), torch.rand(129, 856)) - inputs_embedding = SO3_Embedding(129, [lmax], 128, inputs_tensor[0].device, inputs_tensor[0].dtype) - inputs_embedding.set_embedding(inputs_tensor[0]) - assert torch.allclose(inputs_tensor[0], inputs_embedding.embedding) - output = so2(inputs_embedding, inputs_tensor[1]) - output_export = so2_export(*inputs_tensor) - assert torch.allclose(output.embedding, output_export) - - def test_so2_conv_exportable(self): - torch._dynamo.config.assume_static_by_default = False - torch._dynamo.config.automatic_dynamic_shapes = True - inp1_dim0 = Dim("inp1_dim0") - inp1_dim1 = None - inp1_dim2 = None - inp2_dim0 = inp1_dim0 - inp2_dim1 = Dim("inp2_dim1") - - dynamic_shapes1 = { - "x_emb": {0: inp1_dim0, 1: inp1_dim1, 2: inp1_dim2}, - "x_edge": {0: inp2_dim0, 1: inp2_dim1}, - } - - lmax, mmax = 4, 2 - mappingReduced = CoefficientMappingModule([lmax], [mmax]) - args=(torch.rand(129, 19, 128), torch.rand(129, 856)) - - so2 = SO2_Convolution_Exportable(sphere_channels=128, m_output_channels=128, lmax_list=[lmax], mmax_list=[mmax],mappingReduced=mappingReduced) - prog = export(so2, args=args, dynamic_shapes=dynamic_shapes1) - export_out = prog.module()(*args) - regular_out = so2(*args) - assert torch.allclose(export_out, regular_out) - - args2=(torch.rand(130, 19, 128), torch.rand(130, 856)) - export_out2 = prog.module()(*args2) - regular_out2 = so2(*args2) - assert torch.allclose(export_out2, regular_out2) - - From c0d8e412dd2ab77e1d17b477dd0d66c41b099f72 Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Mon, 2 Sep 2024 16:27:01 -0700 Subject: [PATCH 16/25] compile works --- src/fairchem/core/_cli.py | 9 ++++----- src/fairchem/core/models/escn/escn_exportable.py | 8 +++----- src/fairchem/core/trainers/base_trainer.py | 7 +++---- 
src/fairchem/core/trainers/ocp_trainer.py | 5 +++++ tests/core/models/test_escn_compiles.py | 4 ++-- 5 files changed, 17 insertions(+), 16 deletions(-) diff --git a/src/fairchem/core/_cli.py b/src/fairchem/core/_cli.py index f1270496a9..47f5cb281b 100644 --- a/src/fairchem/core/_cli.py +++ b/src/fairchem/core/_cli.py @@ -15,6 +15,7 @@ from submitit.helpers import Checkpointable, DelayedSubmission from torch.distributed.launcher.api import LaunchConfig, elastic_launch +from fairchem.core.common.distutils import init_local_distributed_process_group from fairchem.core.common.flags import flags from fairchem.core.common.utils import ( build_config, @@ -94,7 +95,7 @@ def main(): logging.info(f"Experiment log saved to: {log_file}") else: # Run locally on a single node, n-processes - if args.distributed: + if args.num_gpus > 1: logging.info( f"Running in distributed local mode with {args.num_gpus} ranks" ) @@ -116,10 +117,8 @@ def main(): ) elastic_launch(launch_config, runner_wrapper)(args.distributed, config) else: - logging.info("Running in non-distributed local mode") - assert ( - args.num_gpus == 1 - ), "Can only run with a single gpu in non distributed local mode, use --distributed flag instead if using >1 gpu" + logging.info("Running in local mode") + init_local_distributed_process_group(backend='nccl') runner_wrapper(args.distributed, config) diff --git a/src/fairchem/core/models/escn/escn_exportable.py b/src/fairchem/core/models/escn/escn_exportable.py index dfeb4f0b23..bccc936a45 100644 --- a/src/fairchem/core/models/escn/escn_exportable.py +++ b/src/fairchem/core/models/escn/escn_exportable.py @@ -268,10 +268,8 @@ def forward(self, data): graph.edge_index, wigner, ) - # Residual layer for all layers past the first - x_xessage = x_message + x_message_new - + x_message = x_message + x_message_new else: # No residual for the first layer x_message = self.layer_blocks[i]( @@ -599,11 +597,11 @@ def forward( x_target = torch.einsum("bai,zbac->zic", from_grid_mat, x_grid) # Rotate back the irreps - wigner_inv = torch.transpose(wigner, 1, 2).contiguous() + wigner_inv = torch.transpose(wigner, 1, 2).contiguous().detach() x_target = torch.bmm(wigner_inv[:, :, self.out_mask], x_target) # Compute the sum of the incoming neighboring messages for each target node - new_embedding = torch.fill(x.clone(), 0) + new_embedding = torch.zeros(x.shape, dtype=x_target.dtype, device=x_target.device) new_embedding.index_add_(0, edge_index[1], x_target) return new_embedding diff --git a/src/fairchem/core/trainers/base_trainer.py b/src/fairchem/core/trainers/base_trainer.py index 6918d45fc5..7931785ef9 100644 --- a/src/fairchem/core/trainers/base_trainer.py +++ b/src/fairchem/core/trainers/base_trainer.py @@ -551,11 +551,10 @@ def load_model(self) -> None: device_ids=None if self.cpu else [self.device], ) - torch._dynamo.config.optimize_ddp = False - torch._dynamo.config.assume_static_by_default = False - torch._dynamo.config.automatic_dynamic_shapes = True - if self.config["optim"].get("compiles"): + torch._dynamo.config.optimize_ddp = False + torch._dynamo.config.assume_static_by_default = False + torch._dynamo.config.automatic_dynamic_shapes = True os.environ["TORCH_LOGS"] = "recompiles" self.model = torch.compile(self.model, dynamic=True) torch._dynamo.config.optimize_ddp = False diff --git a/src/fairchem/core/trainers/ocp_trainer.py b/src/fairchem/core/trainers/ocp_trainer.py index 0ced35bef3..f9c0ecdbec 100644 --- a/src/fairchem/core/trainers/ocp_trainer.py +++ b/src/fairchem/core/trainers/ocp_trainer.py @@ 
-12,6 +12,7 @@ from collections import defaultdict from itertools import chain from typing import TYPE_CHECKING +import time import numpy as np import torch @@ -139,6 +140,7 @@ def train(self, disable_eval_tqdm: bool = False) -> None: # Calculate start_epoch from step instead of loading the epoch number # to prevent inconsistencies due to different batch size in checkpoint. start_epoch = self.step // len(self.train_loader) + previous_wall_time = time.time() for epoch_int in range(start_epoch, self.config["optim"]["max_epochs"]): skip_steps = self.step % len(self.train_loader) @@ -182,6 +184,9 @@ def train(self, disable_eval_tqdm: bool = False) -> None: self.step % self.config["cmd"]["print_every"] == 0 and distutils.is_master() ): + time_delta = time.time() - previous_wall_time + previous_wall_time = time.time() + log_dict.update({'step_per_s' : self.config["cmd"]["print_every"] / time_delta}) log_str = [f"{k}: {v:.2e}" for k, v in log_dict.items()] logging.info(", ".join(log_str)) self.metrics = {} diff --git a/tests/core/models/test_escn_compiles.py b/tests/core/models/test_escn_compiles.py index e653ef4051..96d3c0e044 100644 --- a/tests/core/models/test_escn_compiles.py +++ b/tests/core/models/test_escn_compiles.py @@ -85,7 +85,7 @@ def init(backend: str): init_local_distributed_process_group(backend=backend) class TestESCNCompiles: - def test_escn_baseline_cpu(self, tol=1e-5): + def test_escn_baseline_cpu(self, tol=1e-8): init('gloo') data = load_data() data = data_list_collater([data]) @@ -98,7 +98,7 @@ def test_escn_baseline_cpu(self, tol=1e-5): assert torch.allclose(base_output["forces"].mean(0), export_output["forces"].mean(0), atol=tol) @skip_if_no_cuda - def test_escn_baseline_cuda(self, tol=1e-5): + def test_escn_baseline_cuda(self, tol=1e-8): init('nccl') data = load_data() data = data_list_collater([data]).to("cuda") From da1165beaf3e5244e6e9a91f0c2a80ab7e7cde60 Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Mon, 9 Sep 2024 13:55:19 -0700 Subject: [PATCH 17/25] update --- src/fairchem/core/common/distutils.py | 10 -- src/fairchem/core/common/test_utils.py | 10 ++ src/fairchem/core/datasets/lmdb_dataset.py | 5 +- .../core/models/escn/escn_exportable.py | 87 +++++++--------- .../so3_utils.py => escn/so3_exportable.py} | 11 ++- .../core/preprocessing/atoms_to_graphs.py | 27 ++++- tests/core/models/test_escn_compiles.py | 98 ++++++++++--------- 7 files changed, 131 insertions(+), 117 deletions(-) rename src/fairchem/core/models/{utils/so3_utils.py => escn/so3_exportable.py} (98%) diff --git a/src/fairchem/core/common/distutils.py b/src/fairchem/core/common/distutils.py index 8bf5b1d426..297218179d 100644 --- a/src/fairchem/core/common/distutils.py +++ b/src/fairchem/core/common/distutils.py @@ -193,13 +193,3 @@ def gather_objects(data: T, group: dist.ProcessGroup = dist.group.WORLD) -> list output = [None for _ in range(get_world_size())] if is_master() else None dist.gather_object(data, output, group=group, dst=0) return output - -def init_local_distributed_process_group(backend="nccl"): - os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = str(get_free_port()) - dist.init_process_group( - rank=0, - world_size=1, - backend=backend, - timeout=timedelta(seconds=10), # setting up timeout for distributed collectives - ) diff --git a/src/fairchem/core/common/test_utils.py b/src/fairchem/core/common/test_utils.py index 130daba2d5..ce86aa782f 100644 --- a/src/fairchem/core/common/test_utils.py +++ b/src/fairchem/core/common/test_utils.py @@ -130,3 +130,13 @@ def 
spawn_multi_process( ) return [mp_output_dict[i] for i in range(config.world_size)] + +def init_local_distributed_process_group(backend="nccl"): + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(get_free_port()) + dist.init_process_group( + rank=0, + world_size=1, + backend=backend, + timeout=timedelta(seconds=10), # setting up timeout for distributed collectives + ) diff --git a/src/fairchem/core/datasets/lmdb_dataset.py b/src/fairchem/core/datasets/lmdb_dataset.py index ca1fcc2b77..ebcd08a261 100644 --- a/src/fairchem/core/datasets/lmdb_dataset.py +++ b/src/fairchem/core/datasets/lmdb_dataset.py @@ -211,7 +211,7 @@ def sample_property_metadata(self, num_samples: int = 100): } -def data_list_collater(data_list: list[BaseData], otf_graph: bool = False) -> BaseData: +def data_list_collater(data_list: list[BaseData], otf_graph: bool = False, to_dict: bool = False) -> BaseData | dict[str, torch.Tensor]: batch = Batch.from_data_list(data_list) if not otf_graph: @@ -226,4 +226,7 @@ def data_list_collater(data_list: list[BaseData], otf_graph: bool = False) -> Ba "LMDB does not contain edge index information, set otf_graph=True" ) + if to_dict: + batch = {k:v for k,v in batch.items()} + return batch diff --git a/src/fairchem/core/models/escn/escn_exportable.py b/src/fairchem/core/models/escn/escn_exportable.py index bccc936a45..99e6f0a6ee 100644 --- a/src/fairchem/core/models/escn/escn_exportable.py +++ b/src/fairchem/core/models/escn/escn_exportable.py @@ -15,16 +15,11 @@ import torch import torch.nn as nn -if typing.TYPE_CHECKING: - from torch_geometric.data.batch import Batch - from fairchem.core.common.registry import registry from fairchem.core.common.utils import conditional_grad -from fairchem.core.models.base import BackboneInterface, GraphModelMixin, HeadInterface -from fairchem.core.models.utils.so3_utils import ( +from fairchem.core.models.escn.so3_exportable import ( CoefficientMapping, SO3_Grid, - SO3_Rotation, rotation_to_wigner ) from fairchem.core.models.scn.sampling import CalcSpherePoints @@ -40,7 +35,7 @@ @registry.register_model("escn_export") -class eSCN(nn.Module, GraphModelMixin): +class eSCN(nn.Module): """Equivariant Spherical Channel Network Paper: Reducing SO(3) Convolutions to SO(2) for Efficient Equivariant GNNs @@ -64,7 +59,6 @@ class eSCN(nn.Module, GraphModelMixin): distance_function ("gaussian", "sigmoid", "linearsigmoid", "silu"): Basis function used for distances basis_width_scalar (float): Width of distance basis function distance_resolution (float): Distance between distance basis functions in Angstroms - show_timing_info (bool): Show timing and memory info """ def __init__( @@ -86,7 +80,6 @@ def __init__( distance_function: str = "gaussian", basis_width_scalar: float = 1.0, distance_resolution: float = 0.02, - show_timing_info: bool = False, resolution: int | None = None, ) -> None: super().__init__() @@ -102,7 +95,6 @@ def __init__( self.use_pbc_single = use_pbc_single self.cutoff = cutoff self.otf_graph = otf_graph - self.show_timing_info = show_timing_info self.max_num_elements = max_num_elements self.hidden_channels = hidden_channels self.num_layers = num_layers @@ -121,9 +113,6 @@ def __init__( self.basis_width_scalar = basis_width_scalar self.distance_function = distance_function - # variables used for display purposes - self.counter = 0 - # non-linear activation function used throughout the network self.act = nn.SiLU() @@ -218,19 +207,22 @@ def __init__( ) - @conditional_grad(torch.enable_grad()) - def forward(self, data): - 
device = data.pos.device - self.batch_size = len(data.natoms) - self.dtype = data.pos.dtype - - start_time = time.time() - atomic_numbers = data.atomic_numbers.long() - assert ( - atomic_numbers.max().item() < self.max_num_elements - ), "Atomic number exceeds that given in model config" + # @conditional_grad(torch.enable_grad()) + def forward(self, data: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + pos: torch.Tensor = data["pos"] + batch_idx: torch.Tensor = data["batch"] + natoms: torch.Tensor = data["natoms"] + atomic_numbers: torch.Tensor = data["atomic_numbers"] + edge_index: torch.Tensor = data["edge_index"] + edge_distance: torch.Tensor = data["distances"] + edge_distance_vec: torch.Tensor = data["edge_distance_vec"] + + atomic_numbers = atomic_numbers.long() + # TODO: this requires upgrade to torch2.4 with export non-strict mode to enable + # assert ( + # atomic_numbers.max().item() < self.max_num_elements + # ), "Atomic number exceeds that given in model config" num_atoms = len(atomic_numbers) - graph = self.generate_graph(data) ############################################################### # Initialize data structures @@ -238,7 +230,7 @@ def forward(self, data): # Compute 3x3 rotation matrix per edge edge_rot_mat = self._init_edge_rot_mat( - data, graph.edge_index, graph.edge_distance_vec + edge_index, edge_distance_vec ) wigner = rotation_to_wigner(edge_rot_mat, 0, self.lmax).detach() @@ -251,8 +243,8 @@ def forward(self, data): num_atoms, int((self.lmax + 1) ** 2), self.sphere_channels, - device=device, - dtype=self.dtype, + device=pos.device, + dtype=pos.dtype, ) x_message[:, 0, :] = self.sphere_embedding(atomic_numbers) @@ -264,8 +256,8 @@ def forward(self, data): x_message_new = self.layer_blocks[i]( x_message, atomic_numbers, - graph.edge_distance, - graph.edge_index, + edge_distance, + edge_index, wigner, ) # Residual layer for all layers past the first @@ -275,8 +267,8 @@ def forward(self, data): x_message = self.layer_blocks[i]( x_message, atomic_numbers, - graph.edge_distance, - graph.edge_index, + edge_distance, + edge_index, wigner, ) @@ -288,8 +280,8 @@ def forward(self, data): # Energy estimation ############################################################### node_energy = self.energy_block(x_pt) - energy = torch.zeros(len(data.natoms), device=device) - energy.index_add_(0, data.batch, node_energy.view(-1)) + energy = torch.zeros(len(natoms), device=node_energy.device) + energy.index_add_(0, batch_idx, node_energy.view(-1)) # Scale energy to help balance numerical precision w.r.t. 
forces energy = energy * 0.001 @@ -301,30 +293,23 @@ def forward(self, data): forces = self.force_block(x_pt, self.sphere_points) outputs["forces"] = forces - if self.show_timing_info is True: - torch.cuda.synchronize() - logging.info( - f"{self.counter} Time: {time.time() - start_time}\tMemory: {len(data.pos)}\t{torch.cuda.max_memory_allocated() / 1000000}" - ) - - self.counter = self.counter + 1 - return outputs # Initialize the edge rotation matrics - def _init_edge_rot_mat(self, data, edge_index, edge_distance_vec): + def _init_edge_rot_mat(self, edge_index, edge_distance_vec): edge_vec_0 = edge_distance_vec edge_vec_0_distance = torch.sqrt(torch.sum(edge_vec_0**2, dim=1)) # Make sure the atoms are far enough apart - if torch.min(edge_vec_0_distance) < 0.0001: - logging.error( - f"Error edge_vec_0_distance: {torch.min(edge_vec_0_distance)}" - ) - (minval, minidx) = torch.min(edge_vec_0_distance, 0) - logging.error( - f"Error edge_vec_0_distance: {minidx} {edge_index[0, minidx]} {edge_index[1, minidx]} {data.pos[edge_index[0, minidx]]} {data.pos[edge_index[1, minidx]]}" - ) + # TODO: this requires upgrade to torch2.4 with export non-strict mode to enable + # if torch.min(edge_vec_0_distance) < 0.0001: + # logging.error( + # f"Error edge_vec_0_distance: {torch.min(edge_vec_0_distance)}" + # ) + # (minval, minidx) = torch.min(edge_vec_0_distance, 0) + # logging.error( + # f"Error edge_vec_0_distance: {minidx} {edge_index[0, minidx]} {edge_index[1, minidx]} {data.pos[edge_index[0, minidx]]} {data.pos[edge_index[1, minidx]]}" + # ) norm_x = edge_vec_0 / (edge_vec_0_distance.view(-1, 1)) diff --git a/src/fairchem/core/models/utils/so3_utils.py b/src/fairchem/core/models/escn/so3_exportable.py similarity index 98% rename from src/fairchem/core/models/utils/so3_utils.py rename to src/fairchem/core/models/escn/so3_exportable.py index 61e37744e0..5c9fe06bac 100644 --- a/src/fairchem/core/models/utils/so3_utils.py +++ b/src/fairchem/core/models/escn/so3_exportable.py @@ -14,7 +14,10 @@ # Borrowed from e3nn @ 0.4.0: # https://github.com/e3nn/e3nn/blob/0.4.0/e3nn/o3/_wigner.py#L10 # _Jd is a list of tensors of shape (2l+1, 2l+1) -_Jd = torch.load(os.path.join(os.path.dirname(__file__), "Jd.pt")) +__Jd = torch.load(os.path.join(os.path.dirname(__file__), "Jd.pt")) +@torch.compiler.assume_constant_result +def get_jd() -> torch.Tensor: + return __Jd # Borrowed from e3nn @ 0.4.0: @@ -25,10 +28,8 @@ def wigner_D( lv: int, alpha: torch.Tensor, beta: torch.Tensor, gamma: torch.Tensor ) -> torch.Tensor: - if not lv < len(_Jd): - raise NotImplementedError( - f"wigner D maximum l implemented is {len(_Jd) - 1}, send us an email to ask for more" - ) + _Jd = get_jd() + assert lv < len(_Jd), f"wigner D maximum l implemented is {len(_Jd) - 1}, send us an email to ask for more" alpha, beta, gamma = torch.broadcast_tensors(alpha, beta, gamma) J = _Jd[lv].to(dtype=alpha.dtype, device=alpha.device) diff --git a/src/fairchem/core/preprocessing/atoms_to_graphs.py b/src/fairchem/core/preprocessing/atoms_to_graphs.py index 2283c40b8a..22f6c471ef 100644 --- a/src/fairchem/core/preprocessing/atoms_to_graphs.py +++ b/src/fairchem/core/preprocessing/atoms_to_graphs.py @@ -115,7 +115,6 @@ def _get_neighbors_pymatgen(self, atoms: ase.Atoms): _c_index, _n_index, _offsets, n_distance = struct.get_neighbor_list( r=self.radius, numerical_tol=0, exclude_self=True ) - _nonmax_idx = [] for i in range(len(atoms)): idx_i = (_c_index == i).nonzero()[0] @@ -148,6 +147,28 @@ def _reshape_features(self, c_index, n_index, n_distance, 
offsets):
 
         return edge_index, edge_distances, cell_offsets
 
+    def get_edge_distance_vec(
+        self,
+        pos,
+        edge_index,
+        cell,
+        cell_offsets,
+    ):
+        row, col = edge_index
+        distance_vectors = pos[row] - pos[col]
+
+        # correct for pbc
+        cell = torch.repeat_interleave(cell, edge_index.shape[1], dim=0)
+        offsets = cell_offsets.float().view(-1, 1, 3).bmm(cell.float()).view(-1, 3)
+        distance_vectors += offsets
+
+        # redundancy: remove zero distances
+        # TODO: do we need this?
+        # distances = distance_vectors.norm(dim=-1)
+        # nonzero_idx = torch.arange(len(distances))[distances != 0]
+
+        return distance_vectors
+
     def convert(self, atoms: ase.Atoms, sid=None):
         """Convert a single atomic structure to a graph.
 
@@ -158,7 +179,7 @@ def convert(self, atoms: ase.Atoms, sid=None):
             tasks. Common sids used in OCP datasets include unique strings or integers.
 
         Returns:
-            data (torch_geometric.data.Data): A torch geometic data object with positions, atomic_numbers, tags,
+            data (torch_geometric.data.Data): A torch geometric data object with positions, atomic_numbers, tags,
             and optionally, energy, forces, distances, edges, and periodic boundary conditions.
             Optional properties can included by setting r_property=True when constructing the class.
         """
@@ -203,6 +224,8 @@ def convert(self, atoms: ase.Atoms, sid=None):
         data.edge_index = edge_index
         data.cell_offsets = cell_offsets
 
+        data.edge_distance_vec = self.get_edge_distance_vec(positions, edge_index, cell, cell_offsets)
+
         del atoms_copy
         if self.r_energy:
             energy = atoms.get_potential_energy(apply_constraint=False)
diff --git a/tests/core/models/test_escn_compiles.py b/tests/core/models/test_escn_compiles.py
index 96d3c0e044..dee01fcf56 100644
--- a/tests/core/models/test_escn_compiles.py
+++ b/tests/core/models/test_escn_compiles.py
@@ -21,7 +21,7 @@
 from torch.nn.parallel.distributed import DistributedDataParallel
 
 from fairchem.core.common.registry import registry
-from fairchem.core.common.distutils import init_local_distributed_process_group
+from fairchem.core.common.test_utils import init_local_distributed_process_group
 from fairchem.core.common.utils import load_state_dict, setup_imports
 from fairchem.core.datasets import data_list_collater
 from fairchem.core.preprocessing import AtomsToGraphs
@@ -29,7 +29,7 @@
 from fairchem.core.models.scn.smearing import GaussianSmearing
 from fairchem.core.models.base import GraphModelMixin
 
-from fairchem.core.models.utils.so3_utils import CoefficientMapping, SO3_Grid, rotation_to_wigner
+from fairchem.core.models.escn.so3_exportable import CoefficientMapping, SO3_Grid, rotation_to_wigner
 from fairchem.core.models.escn import escn_exportable
 
 from torch.export import export
@@ -45,10 +45,11 @@ def load_data():
         format="json",
     )
     a2g = AtomsToGraphs(
-        max_neigh=200,
-        radius=6,
-        r_edges=False,
+        max_neigh=300,
+        radius=6.0,
+        r_edges=True,
         r_fixed=True,
+        r_distances=True,
     )
     data_list = a2g.convert_all([atoms])
     return data_list[0]
@@ -62,8 +63,8 @@ def load_model(name: str):
         use_pbc_single = False,
         regress_forces = True,
         otf_graph = True,
-        max_neighbors = 20,
-        cutoff = 8.0,
+        max_neighbors = 300,
+        cutoff = 6.0,
         max_num_elements = 90,
         num_layers = 8,
         lmax_list = [4],
@@ -75,7 +76,6 @@ def load_model(name: str):
         distance_function = "gaussian",
         basis_width_scalar = 1.0,
         distance_resolution = 0.02,
-        show_timing_info = False,
         resolution = None,
     )
     return model
@@ -88,11 +88,13 @@ class TestESCNCompiles:
     def test_escn_baseline_cpu(self, tol=1e-8):
         init('gloo')
         data = load_data()
-        data = data_list_collater([data])
+        data_tg =
data_list_collater([data]) + data_export = data_list_collater([data], to_dict=True) + base_model = DistributedDataParallel(load_model("escn")) export_model = DistributedDataParallel(load_model("escn_export")) - base_output = base_model(data) - export_output = export_model(data) + base_output = base_model(data_tg) + export_output = export_model(data_export) torch.set_printoptions(precision=8) assert torch.allclose(base_output["energy"], export_output["energy"], atol=tol) assert torch.allclose(base_output["forces"].mean(0), export_output["forces"].mean(0), atol=tol) @@ -101,11 +103,14 @@ def test_escn_baseline_cpu(self, tol=1e-8): def test_escn_baseline_cuda(self, tol=1e-8): init('nccl') data = load_data() - data = data_list_collater([data]).to("cuda") - base_model = DistributedDataParallel(load_model("escn")).cuda() - export_model = DistributedDataParallel(load_model("escn_export")).cuda() - base_output = base_model(data) - export_output = export_model(data) + data_tg = data_list_collater([data]).to("cuda") + data_export = data_list_collater([data], to_dict=True) + data_export_cu = {k:v.to("cuda") for k,v in data_export.items()} + + base_model = DistributedDataParallel(load_model("escn").cuda()) + export_model = DistributedDataParallel(load_model("escn_export").cuda()) + base_output = base_model(data_tg) + export_output = export_model(data_export_cu) torch.set_printoptions(precision=8) assert torch.allclose(base_output["energy"], export_output["energy"], atol=tol) assert torch.allclose(base_output["forces"].mean(0), export_output["forces"].mean(0), atol=tol) @@ -120,7 +125,7 @@ def test_rotation_invariance(self) -> None: assert not np.array_equal(data.pos, data_rotated.pos) # Pass it through the model. - batch = data_list_collater([data, data_rotated]) + batch = data_list_collater([data, data_rotated], to_dict=True) model = load_model("escn_export") model.eval() out = model(batch) @@ -201,19 +206,11 @@ def test_escn_message_block_exports_and_compiles(self, tol=1e-5) -> None: mappingReduced = mappingReduced ) - data = load_data() - data = data_list_collater([data]) - full_model = load_model("escn_export") - graph = full_model.generate_graph(data) - edge_rot_mat = full_model._init_edge_rot_mat( - data, graph.edge_index, graph.edge_distance_vec - ) - wigner = rotation_to_wigner(edge_rot_mat, 0, lmax).detach() - # generate inputs batch_sizes = [34] num_coefs = 25 - num_edges = 680 + num_edges = 2000 + wigner = torch.rand([num_edges, num_coefs, num_coefs]) args = [] for b in batch_sizes: x = torch.rand([b, num_coefs, sphere_channels]) @@ -279,10 +276,6 @@ def test_escn_layer_block_exports_and_compiles(self, tol=1e-5) -> None: torch._dynamo.config.assume_static_by_default = False torch._dynamo.config.automatic_dynamic_shapes = True torch._dynamo.config.verbose = True - # torch._logging.set_logs(dynamo = logging.INFO) - # torch._dynamo.reset() - # explain_output = torch._dynamo.explain(message_block)(*args[0]) - # print(explain_output) batch_dim = Dim("batch_dim") edges_dim = Dim("edges_dim") @@ -302,31 +295,40 @@ def test_escn_layer_block_exports_and_compiles(self, tol=1e-5) -> None: assert torch.allclose(compiled_output, regular_out, atol=tol) assert torch.allclose(exported_output, regular_out, atol=tol) - def test_escn_compiles(self): + def test_full_escn_compiles(self, tol=1e-5): init("gloo") data = load_data() - data = data_list_collater([data]) + regular_data = data_list_collater([data]) + compile_data = data_list_collater([data], to_dict=True) model = load_model('escn_export') ddp_model = 
DistributedDataParallel(model) torch._dynamo.config.optimize_ddp = False torch._dynamo.config.assume_static_by_default = False torch._dynamo.config.automatic_dynamic_shapes = True - # torch._dynamo.config.suppress_errors = True + compiled_model = torch.compile(ddp_model, dynamic=True) + output = compiled_model(compile_data) + expected_output = ddp_model(regular_data) + assert torch.allclose(expected_output["energy"], output["energy"], atol=tol) + assert torch.allclose(expected_output["forces"].mean(0), output["forces"].mean(0), atol=tol) + + def test_full_escn_exports(self): + init("gloo") + data = load_data() + regular_data = data_list_collater([data]) + export_data = data_list_collater([data], to_dict=True) + model = load_model('escn_export') - # os.environ["TORCH_LOGS"] = "+dynamo,recompiles" + torch._dynamo.config.optimize_ddp = False + torch._dynamo.config.assume_static_by_default = False + torch._dynamo.config.automatic_dynamic_shapes = True # torch._logging.set_logs(dynamo = logging.INFO) - # os.environ["TORCHDYNAMO_VERBOSE"] = "1" - # os.environ["TORCHDYNAMO_REPRO_AFTER"]="dynamo" - # torch._dynamo.config.verbose = True - compiled_model = torch.compile(model, dynamic=True) - # torch._dynamo.explain(model)(data) - # assert False # torch._dynamo.reset() - # explain_output = torch._dynamo.explain(model)(data) - # print(explain_output) - - output = compiled_model(data) - # expected_energy, expected_forces = expected_energy_forces() - # assert torch.allclose(output["energy"], expected_energy) - # assert torch.allclose(output["forces"].mean(0), expected_forces) + # explained_output = torch._dynamo.explain(model)(*data) + # print(explained_output) + # TODO: add dynamic shapes + exported_prog = export(model, args=(export_data,)) + export_output = exported_prog(export_data) + expected_output = model(regular_data) + assert torch.allclose(export_output["energy"], expected_output["energy"]) + assert torch.allclose(export_output["forces"].mean(0), expected_output["forces"].mean(0)) From 59ea9db0d246d1067dbba2cc9ae607390fc5ea6f Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Mon, 9 Sep 2024 14:04:37 -0700 Subject: [PATCH 18/25] remove some files from main --- src/fairchem/core/_cli.py | 9 +++++---- src/fairchem/core/common/distutils.py | 1 - src/fairchem/core/models/utils/Jd.pt | Bin 21697 -> 0 bytes tests/core/models/atoms.json | 22 ++-------------------- 4 files changed, 7 insertions(+), 25 deletions(-) delete mode 100644 src/fairchem/core/models/utils/Jd.pt diff --git a/src/fairchem/core/_cli.py b/src/fairchem/core/_cli.py index 47f5cb281b..f1270496a9 100644 --- a/src/fairchem/core/_cli.py +++ b/src/fairchem/core/_cli.py @@ -15,7 +15,6 @@ from submitit.helpers import Checkpointable, DelayedSubmission from torch.distributed.launcher.api import LaunchConfig, elastic_launch -from fairchem.core.common.distutils import init_local_distributed_process_group from fairchem.core.common.flags import flags from fairchem.core.common.utils import ( build_config, @@ -95,7 +94,7 @@ def main(): logging.info(f"Experiment log saved to: {log_file}") else: # Run locally on a single node, n-processes - if args.num_gpus > 1: + if args.distributed: logging.info( f"Running in distributed local mode with {args.num_gpus} ranks" ) @@ -117,8 +116,10 @@ def main(): ) elastic_launch(launch_config, runner_wrapper)(args.distributed, config) else: - logging.info("Running in local mode") - init_local_distributed_process_group(backend='nccl') + logging.info("Running in non-distributed local mode") + assert ( + args.num_gpus == 1 + 
), "Can only run with a single gpu in non distributed local mode, use --distributed flag instead if using >1 gpu" runner_wrapper(args.distributed, config) diff --git a/src/fairchem/core/common/distutils.py b/src/fairchem/core/common/distutils.py index 297218179d..f6bf88ccaf 100644 --- a/src/fairchem/core/common/distutils.py +++ b/src/fairchem/core/common/distutils.py @@ -17,7 +17,6 @@ import torch.distributed as dist from fairchem.core.common.typing import none_throws -from torch.distributed.elastic.utils.distributed import get_free_port T = TypeVar("T") diff --git a/src/fairchem/core/models/utils/Jd.pt b/src/fairchem/core/models/utils/Jd.pt deleted file mode 100644 index 01ed13e3f1c78bfcdf09077adc30aa3ab4a64685..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 21697 zcmdU13tWxa7C-4JglO{4qd`M $72cRX^$B$8C72yG6Ady!Yec%(55V;D3>7*{i9 zOvyNDxH6uFL`a^Ayk^Ek9(R9dxA!{V>U4j%JISej-*30RZ~ymRYkljr_t_m^m8qFR zVP&N#^;1V-s|cPN5*{%l%r!Jv6YLT-F|x0>jUqty6J0HkDB6Z-rqG@)LDMx6k<+Y$ zriP819uXNDqzRimZOYW386MWeq-|S&OFG9i+B5j0Fj9>{tb0$HJ}xqB7`tmtO#Mx0 zCpTi|Z%W&csOiL9WuhS#LsbfuiOM8~ShB`yX-LSF$jGn|O~jPR)2xR~4Go(b7TP-^ zL}N{=JE@7ahSX4#n&EWX@M_^~RCiJ<++0;%Lu{ncOjV{aq&98npmmmpOQ5%V~8VdDCk`VYONt{)TFHp)R8z*&@2skPXg6Y)rcYO zXhT8m8OT{fT-3x>2I@rIC}@XZ$6*kbbnGsQwH!Ktp`gL?uwwsIj=ak%1JJOksm0n4QWlh76_+g$-da zUkw?mCTcm%oeZO}y45tqPlDO2>|@Aq+ECaC2K!J${MBS+n2WE<(rW1G*bRkrKI!r; z(|K$Z_A=*jb?fV054n!ck0|5x=aEtx`fk0Vqmm;^*MA$jn|x^ZzTTY;4K4H2vIuE* z_<5v-zoox0Eq@9W9l&8&4i|0eT{=@8OUZQxO@W&S$X5N@Do^u>)w~BsIe;&{LHu5*o z56|O2Z)MG7|2yLO6F2iMuk*76UOukzQm)_J7(FTFijZS=lf)rUJkAMvwfD)Msx`NT zT$@)cANA|wOXB%Ri@zL4CnSmITvDt*-Tme_MZcqM$21+)@|qZ@;Fi1d@ZDL09`ZEu z{HH^5M1MZ6vA2$UMzv^=dsfJyyP!z z6*xA}7BSD>mUX6uetkvId|YF1KRp|8=-TsiAy=0t4^CWc0&=!_vTDHOuYukgzqi&c z_0OVzR{f$S`PD$bTwd(I*#5Bo81B5#-@n5Y#Xav{6*M2$*ju|^gO=@5CJFr%-F4lY zl${~yZ@<_y!D0SoK`(qA5$hLrUC;?#KDylP!VN)ld2#&1{)_F;p#GwBm3+J^(MCg_ z(OB=b+BqBI*W*K{Zgy-7e(Zhx{e~H_*+OpxX01jQ$L0t+#zKiAzF8{Qe>!o5zJGZMAQ!hQx_oUkj+<+c*X zO`P9!^XSEA9AB`1R26+<`@{NUa;y^iT47QcImjRK!^UQ{Pppc+A)a?wI_R3&SJ_%z zU*Y-y=QkXmaeS#N`o#W=?GNiuC)Z}s=h_V(zXZfV++=onNQvuHTwme(0OvOxpR0;K zas0#ni|r5VPbX*eNnxM&LziYvj@zJR_6PMnu1|4&h3kW|@|&^imFEx6Cy#%e=KjlR zZhxHS`r|Zf$LiLbr+nP1Lk)KwV4=QpN$8=(?u4A1PKUQSaSql;&o{QU znh*f%#9DqefBZ*Yrs$t|Zi43-r`=*at_iGPcK@2__kE2!&Fk&HDaPSA-UiRJZGF86 z_7V8~s@67W{ON^+s~3da7w$C*e6;wGpo{KXZp+yN>!pHb{l;!ih4udF-BvkH&q4oV zmf4=PzYBFP*9W#=gZhm=D)hamWz@q(Nni(!SJyOm7y46U_T`;e6UZUM`^#u@d4`z>z6$T z`(Ijpk6+Z>bAm67uPXF0%X2=%cfdM4Fgto?dIH$hwpKMv9y**7a<=pxab-bhg3$k; z8~nW1dtDT?-^b0bJ^T{lKi30}zc@Z%`!%fJBZ3b;ZdGezGuo^U7zX-2pEPpq^Fr|3 zjUAK64eJQ>{Rq2H_T}aXx!HOa^4zgkliGEd4CCQ^jPnJKXV~AcJz##Tn!e@yZWDI; zI`?()>sMMquSsvbIeh(bL2Ezz>V;bz#PK;mnPWf#F&``&ziHT=TS`9agjIh$ zn{Jz-#q~O_Cvm-l^D)jBIG*wOs#=TV1GZnRU!8mnpblIx$|0%q`m4gee^kX(KN|d# z60e7Fy^iZiT<_p~jPpfR(=+E6$6p*Du>E5F>c+X2ChU7<{Jp~Du$5Z8-o)!6T(9GL z64yI8A9MbzTFdhr=bOi0PV@M{X>Pxq=KAF{YkhV5A6}2U7c8lxP)JwV`yVz6^!dL> z#eM!S=c??_{~rl^!3e`NtV4DV^v|EY?_P<2|F7o@U{pDt-@nH6*WPnkI=v-b<j$)!d0`YYGLzAWdQwdK{P94%2Lfe|!eR^XFGEKiPs0wy&NiXy>%m z8FhBvQDXiu|D4~y+v54zc(A|0^?~0X=I`(F!SP&vPV@OW&DvC*{V!=?_IL?BmL^^0 zu>X#gVgEB-ZC1EWgZ+=#ZJjcA=R;l1;={2r>Dy|E z=MNf)?|e?f{J}W9Uq!bpE<)Xm?PPH?yX2LAaJ~Y+4=gVpucG;tkC!U^J=C_fPrv)H zk5lu+#*{V1_m$XQ@cY2>%Ezxb3iHQCE)e{`cn}iddH#|X+cUNo{5~p<=q>H*(yYqZq$%PN&4)1z&7ER(ii4^UpX;RB}$|9&Xj% z`&nrmfBvRgZ2xQ?h(Ct;Rw%Ik4f7BC1#gOf`M9v($@~e~=cgT}yj(qX$$s&k(h4Wfu^a>PbAIvrGb|tEbB?R**74_lpF2zk`I^M0 zdzc47-ubu&#C8Gb=g42kI5<57=I?KJk2&Ew6mOAH{p@F(Raok29R(XumjR zfORtD%l=z0MFot%r^Wt`{Q=tx)~9^_%A>G+x|yS4Ur*ER;E_|^ArEdnFzSQZ@wrMI 
From: Ray Gao
Date: Mon, 9 Sep 2024 15:23:41 -0700
Subject: [PATCH 19/25] lint

---
 configs/s2ef/2M/base.yml                   |  4 +-
 src/fairchem/core/common/utils.py          |  8 ++--
 .../core/models/escn/escn_exportable.py    | 25 +++-------
 .../core/preprocessing/atoms_to_graphs.py  |  2 +-
 src/fairchem/core/trainers/ocp_trainer.py  |  5 --
 tests/core/models/test_escn_compiles.py    | 47 +++++++++----------
 6 files changed, 35 insertions(+), 56 deletions(-)

diff --git a/configs/s2ef/2M/base.yml b/configs/s2ef/2M/base.yml
index 69d8401bdd..cea1f121b0 100755
--- a/configs/s2ef/2M/base.yml
+++ b/configs/s2ef/2M/base.yml
@@ -3,7 +3,7 @@ trainer: ocp
 dataset:
   train:
     format: lmdb
-    src: /home/rgao/s2ef/s2ef/200k/train/
+    src: data/s2ef/2M/train/
    key_mapping:
      y: energy
      force: forces
@@ -16,7 +16,7 @@ dataset:
    mean: 0
    stdev: 2.887317180633545
  val:
-    src: /home/rgao/s2ef/s2ef/200k/train/
+    src: data/s2ef/all/val_id/

logger: wandb

diff --git a/src/fairchem/core/common/utils.py b/src/fairchem/core/common/utils.py
index b0eff00ce6..955ea1e062 100644
--- a/src/fairchem/core/common/utils.py
+++ b/src/fairchem/core/common/utils.py
@@ -763,7 +763,6 @@ def radius_graph_pbc(
         atom_distance=atom_distance_sqr,
         max_num_neighbors_threshold=max_num_neighbors_threshold,
         enforce_max_strictly=enforce_max_neighbors_strictly,
-        batch=data.batch,
     )

     if not torch.all(mask_num_neighbors):
@@ -787,7 +786,6 @@ def get_max_neighbors_mask(
     max_num_neighbors_threshold,
     degeneracy_tolerance: float = 0.01,
     enforce_max_strictly: bool = False,
-    batch=None,
 ):
     """
     Give a mask that filters out edges so that each atom has at most
@@ -810,12 +808,14 @@ def get_max_neighbors_mask(
     # Get number of neighbors
     # segment_coo assumes sorted index
     ones = index.new_ones(1).expand_as(index)
-    num_neighbors = scatter(ones, index, dim_size=num_atoms)
+    num_neighbors = segment_coo(ones, index, dim_size=num_atoms)
     max_num_neighbors = num_neighbors.max()
     num_neighbors_thresholded = num_neighbors.clamp(max=max_num_neighbors_threshold)

     # Get number of (thresholded) neighbors per image
-    num_neighbors_image = scatter(num_neighbors_thresholded, batch, dim_size=natoms.shape[0])
+    image_indptr = torch.zeros(natoms.shape[0] + 1, device=device, dtype=torch.long)
+    image_indptr[1:] = torch.cumsum(natoms, dim=0)
+    num_neighbors_image = segment_csr(num_neighbors_thresholded, image_indptr)

     # If max_num_neighbors is below the threshold, return early
     if (
diff --git a/src/fairchem/core/models/escn/escn_exportable.py b/src/fairchem/core/models/escn/escn_exportable.py
index 99e6f0a6ee..2a651827d2 100644
--- a/src/fairchem/core/models/escn/escn_exportable.py
+++ b/src/fairchem/core/models/escn/escn_exportable.py
@@ -9,18 +9,15 @@

 import contextlib
 import logging
-import time
-import typing

 import torch
 import torch.nn as nn

 from fairchem.core.common.registry import registry
-from fairchem.core.common.utils import conditional_grad
 from fairchem.core.models.escn.so3_exportable import (
     CoefficientMapping,
     SO3_Grid,
-    rotation_to_wigner
+    rotation_to_wigner,
 )
 from fairchem.core.models.scn.sampling import CalcSpherePoints
 from fairchem.core.models.scn.smearing import (
@@ -71,8 +68,8 @@ def __init__(
         cutoff: float = 8.0,
         max_num_elements: int = 90,
         num_layers: int = 8,
-        lmax_list: List[int] = [4], # list of 1, for backward compat only right now,
-        mmax_list: List[int] = [2], # list of 1, for backward compat only right now,
+        lmax_list: list[int] = (4,), # list of 1, for backward compat only right now,
+        mmax_list: list[int] = (2,), # list of 1, for backward compat only right now,
         sphere_channels: int = 128,
         hidden_channels: int = 256,
         edge_channels: int = 128,
@@ -107,7 +104,8 @@ def __init__(
         self.grad_forces = False
         self.lmax_list = lmax_list
         self.mmax_list = mmax_list
-        assert len(self.lmax_list) == 1 and len(self.mmax_list) == 1
+        assert len(self.lmax_list) == 1
+        assert len(self.mmax_list) == 1
         self.lmax = lmax_list[0]
         self.mmax = mmax_list[0]
         self.basis_width_scalar = basis_width_scalar
@@ -443,9 +441,7 @@ def forward(
         )

         # Compute point-wise spherical non-linearity on aggregated messages
-
         # Project to grid
-        # x_grid_message = x_message.to_grid(self.SO3_grid["lmax_lmax"])
         to_grid_mat = self.SO3_grid["lmax_lmax"].to_grid_mat[:, :, self.SO3_grid["lmax_lmax"].mapping.coefficient_idx(self.lmax, self.lmax)]
         x_grid_message = torch.einsum("bai,zic->zbac", to_grid_mat, x_message)
@@ -459,12 +455,8 @@ def forward(
         x_grid = self.fc3_sphere(x_grid)

         # Project back to spherical harmonic coefficients
-        # x_message._from_grid(x_grid, self.SO3_grid["lmax_lmax"])
         from_grid_mat = self.SO3_grid["lmax_lmax"].from_grid_mat[:, :, self.SO3_grid["lmax_lmax"].mapping.coefficient_idx(self.lmax, self.lmax)]
-        x_message_final = torch.einsum("bai,zbac->zic", from_grid_mat, x_grid)
-
-        # Return aggregated messages
-        return x_message_final
+        return torch.einsum("bai,zbac->zic", from_grid_mat, x_grid)


 class MessageBlock(torch.nn.Module):
@@ -686,10 +678,7 @@ def forward(
             offset = offset + 2 * self.mappingReduced.m_size[m]

         # Reshape the spherical harmonics based on l (degree)
-        # x._l_primary(self.mappingReduced)
-        x = torch.einsum("nac,ab->nbc", x, self.mappingReduced.to_m)
-
-        return x
+        return torch.einsum("nac,ab->nbc", x, self.mappingReduced.to_m)


 class SO2Conv(torch.nn.Module):
diff --git a/src/fairchem/core/preprocessing/atoms_to_graphs.py b/src/fairchem/core/preprocessing/atoms_to_graphs.py
index 22f6c471ef..8d418e091d 100644
--- a/src/fairchem/core/preprocessing/atoms_to_graphs.py
+++ b/src/fairchem/core/preprocessing/atoms_to_graphs.py
@@ -179,7 +179,7 @@ def convert(self, atoms: ase.Atoms, sid=None):
                 tasks. Common sids used in OCP datasets include unique strings or integers.

         Returns:
-            data (torch_geometric.dqata.Data): A torch geometic data object with positions, atomic_numbers, tags,
+            data (torch_geometric.data.Data): A torch geometric data object with positions, atomic_numbers, tags,
             and optionally, energy, forces, distances, edges, and periodic boundary conditions.
             Optional properties can be included by setting r_property=True when constructing the class.
         """
diff --git a/src/fairchem/core/trainers/ocp_trainer.py b/src/fairchem/core/trainers/ocp_trainer.py
index f9c0ecdbec..0ced35bef3 100644
--- a/src/fairchem/core/trainers/ocp_trainer.py
+++ b/src/fairchem/core/trainers/ocp_trainer.py
@@ -12,7 +12,6 @@
 from collections import defaultdict
 from itertools import chain
 from typing import TYPE_CHECKING
-import time

 import numpy as np
 import torch
@@ -140,7 +139,6 @@ def train(self, disable_eval_tqdm: bool = False) -> None:
         # Calculate start_epoch from step instead of loading the epoch number
         # to prevent inconsistencies due to different batch size in checkpoint.
         start_epoch = self.step // len(self.train_loader)
-        previous_wall_time = time.time()

         for epoch_int in range(start_epoch, self.config["optim"]["max_epochs"]):
             skip_steps = self.step % len(self.train_loader)
@@ -184,9 +182,6 @@ def train(self, disable_eval_tqdm: bool = False) -> None:
                     self.step % self.config["cmd"]["print_every"] == 0
                     and distutils.is_master()
                 ):
-                    time_delta = time.time() - previous_wall_time
-                    previous_wall_time = time.time()
-                    log_dict.update({'step_per_s' : self.config["cmd"]["print_every"] / time_delta})
                     log_str = [f"{k}: {v:.2e}" for k, v in log_dict.items()]
                     logging.info(", ".join(log_str))
                     self.metrics = {}
diff --git a/tests/core/models/test_escn_compiles.py b/tests/core/models/test_escn_compiles.py
index dee01fcf56..1bd87abed3 100644
--- a/tests/core/models/test_escn_compiles.py
+++ b/tests/core/models/test_escn_compiles.py
@@ -7,33 +7,29 @@

 from __future__ import annotations

-import copy
-import io
+import logging
 import os
 import random

-import numpy as np
-import logging
+import numpy as np
 import pytest
-import requests
 import torch
 from ase.io import read
+from torch.export import Dim, export
 from torch.nn.parallel.distributed import DistributedDataParallel

 from fairchem.core.common.registry import registry
 from fairchem.core.common.test_utils import init_local_distributed_process_group
-from fairchem.core.common.utils import load_state_dict, setup_imports
-from fairchem.core.datasets import data_list_collater
-from fairchem.core.preprocessing import AtomsToGraphs
 from fairchem.core.common.transforms import RandomRotate
-from fairchem.core.models.scn.smearing import GaussianSmearing
-from fairchem.core.models.base import GraphModelMixin
-
-from fairchem.core.models.escn.so3_exportable import CoefficientMapping, SO3_Grid, rotation_to_wigner
 from fairchem.core.models.escn import escn_exportable
-
-from torch.export import export
-from torch.export import Dim
+from fairchem.core.common.utils import setup_imports
+from fairchem.core.datasets import data_list_collater
+from fairchem.core.models.escn.so3_exportable import (
+    CoefficientMapping,
+    SO3_Grid,
+)
+from fairchem.core.models.scn.smearing import GaussianSmearing
+from fairchem.core.preprocessing import AtomsToGraphs

 skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="skipping when no gpu")
@@ -58,7 +54,7 @@ def load_data():

 def load_model(name: str):
     torch.manual_seed(4)
     setup_imports()
-    model = registry.get_model_class(name)(
+    return registry.get_model_class(name)(
         use_pbc = True,
use_pbc_single = False, regress_forces = True, @@ -78,7 +74,6 @@ def load_model(name: str): distance_resolution = 0.02, resolution = None, ) - return model def init(backend: str): if not torch.distributed.is_initialized(): @@ -86,7 +81,7 @@ def init(backend: str): class TestESCNCompiles: def test_escn_baseline_cpu(self, tol=1e-8): - init('gloo') + init("gloo") data = load_data() data_tg = data_list_collater([data]) data_export = data_list_collater([data], to_dict=True) @@ -101,7 +96,7 @@ def test_escn_baseline_cpu(self, tol=1e-8): @skip_if_no_cuda def test_escn_baseline_cuda(self, tol=1e-8): - init('nccl') + init("nccl") data = load_data() data_tg = data_list_collater([data]).to("cuda") data_export = data_list_collater([data], to_dict=True) @@ -163,12 +158,12 @@ def test_escn_so2_conv_exports_and_compiles(self, tol=1e-5) -> None: args=(torch.rand(680, 19, shpere_channels), torch.rand(680, edge_channels)) so2 = escn_exportable.SO2Block( - sphere_channels=shpere_channels, + sphere_channels=shpere_channels, hidden_channels=128, edge_channels=edge_channels, - lmax=lmax, - mmax=mmax, - act=torch.nn.SiLU(), + lmax=lmax, + mmax=mmax, + act=torch.nn.SiLU(), mappingReduced=mappingReduced ) prog = export(so2, args=args, dynamic_shapes=dynamic_shapes1) @@ -284,7 +279,7 @@ def test_escn_layer_block_exports_and_compiles(self, tol=1e-5) -> None: "atomic_numbers": {0: batch_dim}, "edge_distance": {0: edges_dim}, "edge_index": {0: None, 1: edges_dim}, - "wigner": {0: edges_dim, 1: None, 2: None} + "wigner": {0: edges_dim, 1: None, 2: None} } exported_prog = export(layer_block, args=run_args[0], dynamic_shapes=dynamic_shapes1) for run_arg in run_args: @@ -300,7 +295,7 @@ def test_full_escn_compiles(self, tol=1e-5): data = load_data() regular_data = data_list_collater([data]) compile_data = data_list_collater([data], to_dict=True) - model = load_model('escn_export') + model = load_model("escn_export") ddp_model = DistributedDataParallel(model) torch._dynamo.config.optimize_ddp = False @@ -317,7 +312,7 @@ def test_full_escn_exports(self): data = load_data() regular_data = data_list_collater([data]) export_data = data_list_collater([data], to_dict=True) - model = load_model('escn_export') + model = load_model("escn_export") torch._dynamo.config.optimize_ddp = False torch._dynamo.config.assume_static_by_default = False From f19b8e689cab7b98b2680d8aef7612241ff98986 Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Mon, 9 Sep 2024 15:37:27 -0700 Subject: [PATCH 20/25] ruff --- src/fairchem/core/datasets/lmdb_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fairchem/core/datasets/lmdb_dataset.py b/src/fairchem/core/datasets/lmdb_dataset.py index ebcd08a261..346987d8e5 100644 --- a/src/fairchem/core/datasets/lmdb_dataset.py +++ b/src/fairchem/core/datasets/lmdb_dataset.py @@ -227,6 +227,6 @@ def data_list_collater(data_list: list[BaseData], otf_graph: bool = False, to_di ) if to_dict: - batch = {k:v for k,v in batch.items()} + batch = dict(batch.items()) return batch From e3490956366cf6981cade0715008e7005ad53dd7 Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Mon, 9 Sep 2024 15:42:29 -0700 Subject: [PATCH 21/25] lint --- .../core/models/escn/escn_exportable.py | 1 + .../core/models/escn/so3_exportable.py | 115 ++++-------------- 2 files changed, 27 insertions(+), 89 deletions(-) diff --git a/src/fairchem/core/models/escn/escn_exportable.py b/src/fairchem/core/models/escn/escn_exportable.py index 2a651827d2..e3f234496c 100644 --- a/src/fairchem/core/models/escn/escn_exportable.py +++ 
b/src/fairchem/core/models/escn/escn_exportable.py
@@ -104,6 +104,7 @@ def __init__(
         self.grad_forces = False
         self.lmax_list = lmax_list
         self.mmax_list = mmax_list
+        # TODO: completely remove the for loops associated with lmax and mmax lists
         assert len(self.lmax_list) == 1
         assert len(self.mmax_list) == 1
         self.lmax = lmax_list[0]
diff --git a/src/fairchem/core/models/escn/so3_exportable.py b/src/fairchem/core/models/escn/so3_exportable.py
index 5c9fe06bac..a5189d22df 100644
--- a/src/fairchem/core/models/escn/so3_exportable.py
+++ b/src/fairchem/core/models/escn/so3_exportable.py
@@ -1,9 +1,9 @@
 from __future__ import annotations

+import math
 import os

 import torch
-import math

 try:
     from e3nn import o3
@@ -93,8 +93,10 @@ def __init__(
         self.mmax_list = mmax_list
         self.num_resolutions = len(lmax_list)

-        assert (len(self.lmax_list) == 1) and (len(self.mmax_list) == 1)
-
+        # TODO: remove these for loops associated with lmax and mmax lists
+        assert len(self.lmax_list) == 1
+        assert len(self.mmax_list) == 1
+
         # Compute the degree (l) and order (m) for each entry of the embedding
         l_harmonic = torch.tensor([]).long()
         m_harmonic = torch.tensor([]).long()
@@ -104,7 +106,7 @@ def __init__(

         offset = 0
         for i in range(self.num_resolutions):
-            for l in range(0, self.lmax_list[i] + 1):
+            for l in range(self.lmax_list[i] + 1):
                 mmax = min(self.mmax_list[i], l)
                 m = torch.arange(-mmax, mmax + 1).long()
                 m_complex = torch.cat([m_complex, m], dim=0)
@@ -139,20 +141,20 @@ def __init__(
         to_m = to_m.detach()

         # save tensors and they will be moved to GPU
-        self.register_buffer('l_harmonic', l_harmonic)
-        self.register_buffer('m_harmonic', m_harmonic)
-        self.register_buffer('m_complex', m_complex)
-        self.register_buffer('to_m', to_m)
+        self.register_buffer("l_harmonic", l_harmonic)
+        self.register_buffer("m_harmonic", m_harmonic)
+        self.register_buffer("m_complex", m_complex)
+        self.register_buffer("to_m", to_m)

         self.pre_compute_coefficient_idx()

     # Return mask containing coefficients of order m (real and imaginary parts)
     def complex_idx(self, m, lmax, m_complex, l_harmonic):
-        '''
-        Add `m_complex` and `l_harmonic` to the input arguments
-        since we cannot use `self.m_complex`.
-        '''
+        """
+        Add `m_complex` and `l_harmonic` to the input arguments
+        since we cannot use `self.m_complex`.
+ """ if lmax == -1: lmax = max(self.lmax_list) @@ -175,9 +177,9 @@ def complex_idx(self, m, lmax, m_complex, l_harmonic): def pre_compute_coefficient_idx(self): - ''' + """ Pre-compute the results of `coefficient_idx()` and access them with `prepare_coefficient_idx()` - ''' + """ lmax = max(self.lmax_list) for l in range(lmax + 1): for m in range(lmax + 1): @@ -186,19 +188,19 @@ def pre_compute_coefficient_idx(self): ) indices = torch.arange(len(mask)) mask_indices = torch.masked_select(indices, mask) - self.register_buffer('coefficient_idx_l{}_m{}'.format(l, m), mask_indices) + self.register_buffer(f"coefficient_idx_l{l}_m{m}", mask_indices) + - def prepare_coefficient_idx(self): - ''' + """ Construct a list of buffers - ''' + """ lmax = max(self.lmax_list) coefficient_idx_list = [] for l in range(lmax + 1): l_list = [] for m in range(lmax + 1): - l_list.append(getattr(self, 'coefficient_idx_l{}_m{}'.format(l, m), None)) + l_list.append(getattr(self, f"coefficient_idx_l{l}_m{m}", None)) coefficient_idx_list.append(l_list) return coefficient_idx_list @@ -210,12 +212,11 @@ def coefficient_idx(self, lmax: int, mmax: int): self.l_harmonic.le(lmax), self.m_harmonic.le(mmax) ) indices = torch.arange(len(mask), device=mask.device) - mask_indices = torch.masked_select(indices, mask) - return mask_indices + return torch.masked_select(indices, mask) else: temp = self.prepare_coefficient_idx() - return temp[lmax][mmax] - + return temp[lmax][mmax] + def pre_compute_rotate_inv_rescale(self): lmax = max(self.lmax_list) @@ -231,75 +232,11 @@ def pre_compute_rotate_inv_rescale(self): rescale_factor = math.sqrt(length / (2 * m + 1)) rotate_inv_rescale[:, start_idx : (start_idx + length), start_idx : (start_idx + length)] = rescale_factor rotate_inv_rescale = rotate_inv_rescale[:, :, mask_indices] - self.register_buffer('rotate_inv_rescale_l{}_m{}'.format(l, m), rotate_inv_rescale) - + self.register_buffer(f"rotate_inv_rescale_l{l}_m{m}", rotate_inv_rescale) + def __repr__(self): return f"{self.__class__.__name__}(lmax_list={self.lmax_list}, mmax_list={self.mmax_list})" -class SO3_Rotation: - """ - Helper functions for Wigner-D rotations - - Args: - rot_mat3x3 (tensor): Rotation matrix - lmax_list (list:int): List of maximum degree of the spherical harmonics - """ - - def __init__( - self, - rot_mat3x3: torch.Tensor, - lmax: list[int], - ) -> None: - super().__init__() - self.device = rot_mat3x3.device - self.dtype = rot_mat3x3.dtype - - self.wigner = self.RotationToWignerDMatrix(rot_mat3x3, 0, lmax) - self.wigner_inv = torch.transpose(self.wigner, 1, 2).contiguous() - - self.wigner = self.wigner.detach() - self.wigner_inv = self.wigner_inv.detach() - - self.lmax = lmax - import pdb;pdb.set_trace() - self.mapping = CoefficientMapping([self.lmax], [self.lmax]) - - - # Rotate the embedding - def rotate(self, embedding, out_lmax, out_mmax) -> torch.Tensor: - out_mask = self.mapping.coefficient_idx(out_lmax, out_mmax) - wigner = self.wigner[:, out_mask, :] - return torch.bmm(wigner, embedding) - - # Rotate the embedding by the inverse of the rotation matrix - def rotate_inv(self, embedding, in_lmax, in_mmax) -> torch.Tensor: - in_mask = self.mapping.coefficient_idx(in_lmax, in_mmax) - wigner_inv = self.wigner_inv[:, :, in_mask] - - return torch.bmm(wigner_inv, embedding) - - # Compute Wigner matrices from rotation matrix - def RotationToWignerDMatrix( - self, edge_rot_mat: torch.Tensor, start_lmax: int, end_lmax: int - ) -> torch.Tensor: - x = edge_rot_mat[:,:,1] - alpha, beta = o3.xyz_to_angles(x) - R = ( - 
o3.angles_to_matrix(alpha, beta, torch.zeros_like(alpha)).transpose(-1, -2) - @ edge_rot_mat - ) - gamma = torch.atan2(R[..., 0, 2], R[..., 0, 0]) - - size = int((end_lmax + 1) ** 2) - int((start_lmax) ** 2) - wigner = torch.zeros(len(alpha), size, size, device=self.device) - start = 0 - for lmax in range(start_lmax, end_lmax + 1): - block = wigner_D(lmax, alpha, beta, gamma) - end = start + block.size()[1] - wigner[:, start:end, start:end] = block - start = end - - return wigner.detach() class SO3_Grid(torch.nn.Module): """ From 488a6ce759362f44aee6bb0d2b28691d95c38aed Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Mon, 9 Sep 2024 16:44:45 -0700 Subject: [PATCH 22/25] revert base trainer changes --- src/fairchem/core/trainers/base_trainer.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/src/fairchem/core/trainers/base_trainer.py b/src/fairchem/core/trainers/base_trainer.py index 7931785ef9..94becb924c 100644 --- a/src/fairchem/core/trainers/base_trainer.py +++ b/src/fairchem/core/trainers/base_trainer.py @@ -551,15 +551,6 @@ def load_model(self) -> None: device_ids=None if self.cpu else [self.device], ) - if self.config["optim"].get("compiles"): - torch._dynamo.config.optimize_ddp = False - torch._dynamo.config.assume_static_by_default = False - torch._dynamo.config.automatic_dynamic_shapes = True - os.environ["TORCH_LOGS"] = "recompiles" - self.model = torch.compile(self.model, dynamic=True) - torch._dynamo.config.optimize_ddp = False - logging.info("torch compiled model") - @property def _unwrapped_model(self): module = self.model From fe41e6184f64089ad7e0704c4fc6e993ddb2d982 Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Mon, 9 Sep 2024 17:10:56 -0700 Subject: [PATCH 23/25] cleanup a2g --- src/fairchem/core/preprocessing/atoms_to_graphs.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/fairchem/core/preprocessing/atoms_to_graphs.py b/src/fairchem/core/preprocessing/atoms_to_graphs.py index 8d418e091d..473448de18 100644 --- a/src/fairchem/core/preprocessing/atoms_to_graphs.py +++ b/src/fairchem/core/preprocessing/atoms_to_graphs.py @@ -162,11 +162,6 @@ def get_edge_distance_vec( offsets = cell_offsets.float().view(-1, 1, 3).bmm(cell.float()).view(-1, 3) distance_vectors += offsets - # redundancy: remove zero distances - # TODO: do we need this? 
- # distances = distance_vectors.norm(dim=-1) - # nonzero_idx = torch.arange(len(distances))[distances != 0] - return distance_vectors def convert(self, atoms: ase.Atoms, sid=None): From 6cdabef1f545441db677b3d8cd6f6d620c13f41a Mon Sep 17 00:00:00 2001 From: Ray Gao Date: Mon, 9 Sep 2024 17:36:48 -0700 Subject: [PATCH 24/25] address comments --- src/fairchem/core/models/escn/escn_exportable.py | 4 ---- tests/core/models/test_escn_compiles.py | 3 +-- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/src/fairchem/core/models/escn/escn_exportable.py b/src/fairchem/core/models/escn/escn_exportable.py index e3f234496c..3cd874bb0b 100644 --- a/src/fairchem/core/models/escn/escn_exportable.py +++ b/src/fairchem/core/models/escn/escn_exportable.py @@ -63,7 +63,6 @@ def __init__( use_pbc: bool = True, use_pbc_single: bool = False, regress_forces: bool = True, - otf_graph: bool = False, max_neighbors: int = 40, cutoff: float = 8.0, max_num_elements: int = 90, @@ -91,7 +90,6 @@ def __init__( self.use_pbc = use_pbc self.use_pbc_single = use_pbc_single self.cutoff = cutoff - self.otf_graph = otf_graph self.max_num_elements = max_num_elements self.hidden_channels = hidden_channels self.num_layers = num_layers @@ -101,7 +99,6 @@ def __init__( self.max_neighbors = max_neighbors self.edge_channels = edge_channels self.distance_resolution = distance_resolution - self.grad_forces = False self.lmax_list = lmax_list self.mmax_list = mmax_list # TODO: completely remove for loops here associated with lmax and mmax lists @@ -206,7 +203,6 @@ def __init__( ) - # @conditional_grad(torch.enable_grad()) def forward(self, data: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: pos: torch.Tensor = data["pos"] batch_idx: torch.Tensor = data["batch"] diff --git a/tests/core/models/test_escn_compiles.py b/tests/core/models/test_escn_compiles.py index 1bd87abed3..141e60fc4e 100644 --- a/tests/core/models/test_escn_compiles.py +++ b/tests/core/models/test_escn_compiles.py @@ -58,7 +58,6 @@ def load_model(name: str): use_pbc = True, use_pbc_single = False, regress_forces = True, - otf_graph = True, max_neighbors = 300, cutoff = 6.0, max_num_elements = 90, @@ -127,7 +126,7 @@ def test_rotation_invariance(self) -> None: # Compare predicted energies and forces (after inv-rotation). 
        energies = out["energy"].detach()
-        np.testing.assert_almost_equal(energies[0], energies[1], decimal=5)
+        np.testing.assert_almost_equal(energies[0], energies[1], decimal=7)

         forces = out["forces"].detach()
         logging.info(forces)

From 14acfc71e36a81525a028351748349f62cd1efde Mon Sep 17 00:00:00 2001
From: Ray Gao
Date: Mon, 9 Sep 2024 18:04:02 -0700
Subject: [PATCH 25/25] cleanup

---
 .../core/models/escn/escn_exportable.py  | 29 +++--------
 tests/core/models/test_escn_compiles.py  | 49 +++++++++++++------
 2 files changed, 41 insertions(+), 37 deletions(-)

diff --git a/src/fairchem/core/models/escn/escn_exportable.py b/src/fairchem/core/models/escn/escn_exportable.py
index 3cd874bb0b..c1a40ff59c 100644
--- a/src/fairchem/core/models/escn/escn_exportable.py
+++ b/src/fairchem/core/models/escn/escn_exportable.py
@@ -38,14 +38,9 @@ class eSCN(nn.Module):

     Args:
-        use_pbc (bool): Use periodic boundary conditions
-        use_pbc_single (bool): Process batch PBC graphs one at a time
-        regress_forces (bool): Compute forces
-        otf_graph (bool): Compute graph On The Fly (OTF)
-        max_neighbors (int): Maximum number of neighbors per atom
-        cutoff (float): Maximum distance between nieghboring atoms in Angstroms
+        regress_forces (bool): Compute forces
+        cutoff (float): Maximum distance between neighboring atoms in Angstroms
         max_num_elements (int): Maximum atomic number
-
         num_layers (int): Number of layers in the GNN
         lmax (int): maximum degree of the spherical harmonics (1 to 10)
         mmax (int): maximum order of the spherical harmonics (0 to lmax)
@@ -60,15 +55,12 @@ class eSCN(nn.Module):
     def __init__(
         self,
-        use_pbc: bool = True,
-        use_pbc_single: bool = False,
         regress_forces: bool = True,
-        max_neighbors: int = 40,
         cutoff: float = 8.0,
         max_num_elements: int = 90,
         num_layers: int = 8,
-        lmax_list: list[int] = (4,), # list of 1, for backward compat only right now,
-        mmax_list: list[int] = (2,), # list of 1, for backward compat only right now,
+        lmax: int = 4,
+        mmax: int = 2,
         sphere_channels: int = 128,
         hidden_channels: int = 256,
         edge_channels: int = 128,
@@ -87,25 +79,16 @@ def __init__(
             raise ImportError

         self.regress_forces = regress_forces
-        self.use_pbc = use_pbc
-        self.use_pbc_single = use_pbc_single
         self.cutoff = cutoff
         self.max_num_elements = max_num_elements
         self.hidden_channels = hidden_channels
         self.num_layers = num_layers
-        self.num_atoms = 0
         self.num_sphere_samples = num_sphere_samples
         self.sphere_channels = sphere_channels
-        self.max_neighbors = max_neighbors
         self.edge_channels = edge_channels
         self.distance_resolution = distance_resolution
-        self.lmax_list = lmax_list
-        self.mmax_list = mmax_list
-        # TODO: completely remove the for loops associated with lmax and mmax lists
-        assert len(self.lmax_list) == 1
-        assert len(self.mmax_list) == 1
-        self.lmax = lmax_list[0]
-        self.mmax = mmax_list[0]
+        self.lmax = lmax
+        self.mmax = mmax
         self.basis_width_scalar = basis_width_scalar
         self.distance_function = distance_function

diff --git a/tests/core/models/test_escn_compiles.py b/tests/core/models/test_escn_compiles.py
index 141e60fc4e..269433d4d0 100644
--- a/tests/core/models/test_escn_compiles.py
+++ b/tests/core/models/test_escn_compiles.py
@@ -51,10 +51,10 @@ def load_data():
     return data_list[0]

-def load_model(name: str):
+def load_escn_model():
     torch.manual_seed(4)
     setup_imports()
-    return registry.get_model_class(name)(
+    return registry.get_model_class("escn")(
         use_pbc = True,
         use_pbc_single = False,
         regress_forces = True,
         max_neighbors = 300,
         cutoff = 6.0,
         max_num_elements = 90,
         num_layers = 8,
         lmax_list = [4],
         mmax_list = [2],
         sphere_channels = 128,
         hidden_channels = 256,
         edge_channels = 128,
         num_sphere_samples = 128,
         distance_function = "gaussian",
         basis_width_scalar = 1.0,
         distance_resolution = 0.02,
         resolution = None,
     )

+def
load_escn_exportable_model(): + torch.manual_seed(4) + setup_imports() + return registry.get_model_class("escn_export")( + regress_forces = True, + cutoff = 6.0, + max_num_elements = 90, + num_layers = 8, + lmax = 4, + mmax = 2, + sphere_channels = 128, + hidden_channels = 256, + edge_channels = 128, + num_sphere_samples = 128, + distance_function = "gaussian", + basis_width_scalar = 1.0, + distance_resolution = 0.02, + resolution = None, + ) + def init(backend: str): if not torch.distributed.is_initialized(): init_local_distributed_process_group(backend=backend) @@ -85,8 +105,8 @@ def test_escn_baseline_cpu(self, tol=1e-8): data_tg = data_list_collater([data]) data_export = data_list_collater([data], to_dict=True) - base_model = DistributedDataParallel(load_model("escn")) - export_model = DistributedDataParallel(load_model("escn_export")) + base_model = DistributedDataParallel(load_escn_model()) + export_model = DistributedDataParallel(load_escn_exportable_model()) base_output = base_model(data_tg) export_output = export_model(data_export) torch.set_printoptions(precision=8) @@ -101,8 +121,8 @@ def test_escn_baseline_cuda(self, tol=1e-8): data_export = data_list_collater([data], to_dict=True) data_export_cu = {k:v.to("cuda") for k,v in data_export.items()} - base_model = DistributedDataParallel(load_model("escn").cuda()) - export_model = DistributedDataParallel(load_model("escn_export").cuda()) + base_model = DistributedDataParallel(load_escn_model().cuda()) + export_model = DistributedDataParallel(load_escn_exportable_model().cuda()) base_output = base_model(data_tg) export_output = export_model(data_export_cu) torch.set_printoptions(precision=8) @@ -120,7 +140,7 @@ def test_rotation_invariance(self) -> None: # Pass it through the model. batch = data_list_collater([data, data_rotated], to_dict=True) - model = load_model("escn_export") + model = load_escn_exportable_model() model.eval() out = model(batch) @@ -294,15 +314,15 @@ def test_full_escn_compiles(self, tol=1e-5): data = load_data() regular_data = data_list_collater([data]) compile_data = data_list_collater([data], to_dict=True) - model = load_model("escn_export") - ddp_model = DistributedDataParallel(model) + escn_model = DistributedDataParallel(load_escn_model()) + exportable_model = load_escn_exportable_model() torch._dynamo.config.optimize_ddp = False torch._dynamo.config.assume_static_by_default = False torch._dynamo.config.automatic_dynamic_shapes = True - compiled_model = torch.compile(ddp_model, dynamic=True) + compiled_model = torch.compile(exportable_model, dynamic=True) output = compiled_model(compile_data) - expected_output = ddp_model(regular_data) + expected_output = escn_model(regular_data) assert torch.allclose(expected_output["energy"], output["energy"], atol=tol) assert torch.allclose(expected_output["forces"].mean(0), output["forces"].mean(0), atol=tol) @@ -311,7 +331,8 @@ def test_full_escn_exports(self): data = load_data() regular_data = data_list_collater([data]) export_data = data_list_collater([data], to_dict=True) - model = load_model("escn_export") + escn_model = load_escn_model() + exportable_model = load_escn_exportable_model() torch._dynamo.config.optimize_ddp = False torch._dynamo.config.assume_static_by_default = False @@ -321,8 +342,8 @@ def test_full_escn_exports(self): # explained_output = torch._dynamo.explain(model)(*data) # print(explained_output) # TODO: add dynamic shapes - exported_prog = export(model, args=(export_data,)) + exported_prog = export(exportable_model, args=(export_data,)) 
export_output = exported_prog(export_data) - expected_output = model(regular_data) + expected_output = escn_model(regular_data) assert torch.allclose(export_output["energy"], expected_output["energy"]) assert torch.allclose(export_output["forces"].mean(0), expected_output["forces"].mean(0))
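The export tests in this series all follow the same torch.export recipe: trace the module once with example inputs, optionally mark dimensions as dynamic, then check the exported program against eager execution. A minimal, self-contained sketch of that recipe on a toy module (TinyNet and the dimension name are hypothetical placeholders, not fairchem code):

    import torch
    from torch.export import Dim, export

    class TinyNet(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(16, 1)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.linear(x).sum(dim=-1)

    model = TinyNet()
    example = torch.rand(8, 16)

    # Mark dim 0 of `x` as dynamic so the exported graph is not specialized
    # to batch=8, analogous to the dynamic_shapes dicts in the SO2 and
    # layer-block tests above.
    batch = Dim("batch")
    exported = export(model, args=(example,), dynamic_shapes={"x": {0: batch}})

    # The exported program should agree with eager execution on new batch sizes.
    for n in (4, 32):
        inp = torch.rand(n, 16)
        assert torch.allclose(exported.module()(inp), model(inp))

The full-model export test above still uses static shapes (hence the "TODO: add dynamic shapes" comment); this sketch is the shape-generic version of the same equivalence check.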