From 997b39d5cd8f8a2b90a6fd1059c55988bee7075b Mon Sep 17 00:00:00 2001
From: Brandon
Date: Mon, 2 Sep 2024 20:15:57 +0000
Subject: [PATCH 1/2] updated gemnet hydra force head to work with amp

---
 src/fairchem/core/models/gemnet_oc/gemnet_oc.py | 16 ++++++++++++++--
 1 file changed, 14 insertions(+), 2 deletions(-)

diff --git a/src/fairchem/core/models/gemnet_oc/gemnet_oc.py b/src/fairchem/core/models/gemnet_oc/gemnet_oc.py
index c5e6efb00..988908779 100644
--- a/src/fairchem/core/models/gemnet_oc/gemnet_oc.py
+++ b/src/fairchem/core/models/gemnet_oc/gemnet_oc.py
@@ -1489,12 +1489,17 @@ def forward(
 @registry.register_model("gemnet_oc_force_head")
 class GemNetOCForceHead(nn.Module, HeadInterface):
     def __init__(
-        self, backbone, num_global_out_layers: int, output_init: str = "HeOrthogonal"
+        self,
+        backbone,
+        num_global_out_layers: int,
+        use_amp: bool = True,
+        output_init: str = "HeOrthogonal",
     ):
         super().__init__()
 
         self.direct_forces = backbone.direct_forces
         self.forces_coupled = backbone.forces_coupled
+        self._use_amp = use_amp
 
         emb_size_edge = backbone.edge_emb.dense.linear.out_features
         if self.direct_forces:
@@ -1523,11 +1528,18 @@ def __init__(
             out_initializer = get_initializer(output_init)
             self.out_forces.reset_parameters(out_initializer)
 
+    @property
+    def use_amp(self):
+        return self._use_amp
+
     def forward(
         self, data: Batch, emb: dict[str, torch.Tensor]
     ) -> dict[str, torch.Tensor]:
         if self.direct_forces:
-            x_F = self.out_mlp_F(torch.cat(emb["xs_F"], dim=-1))
+            if self.use_amp:
+                x_F = self.out_mlp_F(torch.cat(emb["xs_F"], dim=-1))
+            else:
+                x_F = self.out_mlp_F(torch.cat(emb["xs_F"], dim=-1).float())
             with torch.cuda.amp.autocast(False):
                 F_st = self.out_forces(x_F.float())
 

From 953d8848dc41b306a6a7fd3bf43b100ae32f76ef Mon Sep 17 00:00:00 2001
From: Brandon
Date: Wed, 4 Sep 2024 19:32:45 +0000
Subject: [PATCH 2/2] remove use_amp

---
 src/fairchem/core/models/gemnet_oc/gemnet_oc.py | 13 ++-----------
 1 file changed, 2 insertions(+), 11 deletions(-)

diff --git a/src/fairchem/core/models/gemnet_oc/gemnet_oc.py b/src/fairchem/core/models/gemnet_oc/gemnet_oc.py
index 988908779..c982b7d43 100644
--- a/src/fairchem/core/models/gemnet_oc/gemnet_oc.py
+++ b/src/fairchem/core/models/gemnet_oc/gemnet_oc.py
@@ -1492,14 +1492,12 @@ def __init__(
         self,
         backbone,
         num_global_out_layers: int,
-        use_amp: bool = True,
         output_init: str = "HeOrthogonal",
     ):
         super().__init__()
 
         self.direct_forces = backbone.direct_forces
         self.forces_coupled = backbone.forces_coupled
-        self._use_amp = use_amp
 
         emb_size_edge = backbone.edge_emb.dense.linear.out_features
         if self.direct_forces:
@@ -1528,20 +1526,13 @@ def __init__(
             out_initializer = get_initializer(output_init)
             self.out_forces.reset_parameters(out_initializer)
 
-    @property
-    def use_amp(self):
-        return self._use_amp
-
     def forward(
         self, data: Batch, emb: dict[str, torch.Tensor]
     ) -> dict[str, torch.Tensor]:
         if self.direct_forces:
-            if self.use_amp:
-                x_F = self.out_mlp_F(torch.cat(emb["xs_F"], dim=-1))
-            else:
-                x_F = self.out_mlp_F(torch.cat(emb["xs_F"], dim=-1).float())
             with torch.cuda.amp.autocast(False):
-                F_st = self.out_forces(x_F.float())
+                x_F = self.out_mlp_F(torch.cat(emb["xs_F"], dim=-1).float())
+                F_st = self.out_forces(x_F)
 
             if self.forces_coupled:  # enforce F_st = F_ts
                 nEdges = emb["edge_idx"].shape[0]
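
Note (editor's addition, not part of the patch series): a minimal, self-contained sketch of the pattern the final diff lands on -- keeping a prediction head numerically stable under AMP by disabling autocast locally and casting the possibly half-precision activations to float32 before the sensitive layers. The module and variable names below (ToyForceHead, emb_size, xs_F) are illustrative stand-ins and not the fairchem API.

    import torch
    import torch.nn as nn


    class ToyForceHead(nn.Module):
        def __init__(self, emb_size: int):
            super().__init__()
            # Stand-ins for the output MLP and final Dense layer of the force head.
            self.out_mlp_F = nn.Sequential(nn.Linear(2 * emb_size, emb_size), nn.SiLU())
            self.out_forces = nn.Linear(emb_size, 1)

        def forward(self, xs_F: list[torch.Tensor]) -> torch.Tensor:
            # Leave the autocast region so the head always runs in float32,
            # regardless of the precision the backbone produced.
            with torch.cuda.amp.autocast(False):
                x_F = self.out_mlp_F(torch.cat(xs_F, dim=-1).float())
                return self.out_forces(x_F)


    # Usage: even if the backbone emitted reduced-precision embeddings under AMP,
    # the predicted forces come out in float32.
    head = ToyForceHead(emb_size=8)
    xs_F = [torch.randn(4, 8), torch.randn(4, 8)]
    forces = head(xs_F)
    assert forces.dtype == torch.float32

Compared with the first patch, this drops the use_amp flag entirely: casting inside the autocast(False) block covers both the AMP and full-precision cases, so no extra constructor argument or property is needed.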