Skip to content

Commit

Permalink
Explicitly initialize weights even if initialization method is "uniform" (#809)
Browse files Browse the repository at this point in the history

* initialize even if uniform is requested

* update syrupy
Loading branch information
misko authored Aug 16, 2024
1 parent 96cf75d commit 2078e48
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
4 changes: 3 additions & 1 deletion src/fairchem/core/models/equiformer_v2/equiformer_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,6 +637,8 @@ def _init_weights(self, m):
if self.weight_init == "normal":
std = 1 / math.sqrt(m.in_features)
torch.nn.init.normal_(m.weight, 0, std)
elif self.weight_init == "uniform":
self._uniform_init_linear_weights(m)

elif isinstance(m, torch.nn.LayerNorm):
torch.nn.init.constant_(m.bias, 0)
Expand All @@ -647,7 +649,7 @@ def _uniform_init_rad_func_linear_weights(self, m):
m.apply(self._uniform_init_linear_weights)

def _uniform_init_linear_weights(self, m):
if isinstance(m, torch.nn.Linear):
if isinstance(m, (torch.nn.Linear, SO3_LinearV2)):
if m.bias is not None:
torch.nn.init.constant_(m.bias, 0)
std = 1 / math.sqrt(m.in_features)
Expand Down
4 changes: 2 additions & 2 deletions tests/core/models/__snapshots__/test_equiformer_v2.ambr
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
# ---
# name: TestEquiformerV2.test_gp.1
Approx(
array([0.12408741], dtype=float32),
array([-0.03269595], dtype=float32),
rtol=0.001,
atol=0.001
)
Expand All @@ -69,7 +69,7 @@
# ---
# name: TestEquiformerV2.test_gp.3
Approx(
array([ 1.4928658e-03, -7.4134972e-05, 2.9909210e-03], dtype=float32),
array([ 0.00208857, -0.00017979, -0.0028318 ], dtype=float32),
rtol=0.001,
atol=0.001
)
Expand Down

0 comments on commit 2078e48

Please sign in to comment.