Skip to content

Commit

Permalink
Fixed bug in #79, while not breaking loading pretrained weights
Browse files Browse the repository at this point in the history
  • Loading branch information
bowen-bd committed Oct 6, 2023
1 parent 4b7286e commit 1a558ea
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 23 deletions.
49 changes: 32 additions & 17 deletions chgnet/model/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def __init__(
hidden_dim: int | Sequence[int] | None = (64, 64),
dropout: float = 0,
activation: str = "silu",
bias: bool = True,
) -> None:
"""Initialize the MLP.
Expand All @@ -61,26 +62,31 @@ def __init__(
dropout (float): the dropout rate before each linear layer. Default: 0
activation (str, optional): The name of the activation function to use
in the gated MLP. Must be one of "relu", "silu", "tanh", or "gelu".
Default = "silu".
Default = "silu"
bias (bool): whether to use bias in each Linear layers.
Default = True
"""
super().__init__()
if hidden_dim in (None, 0):
layers = [nn.Dropout(dropout), nn.Linear(input_dim, output_dim)]
layers = [nn.Dropout(dropout), nn.Linear(input_dim, output_dim, bias=bias)]
elif isinstance(hidden_dim, int):
layers = [
nn.Linear(input_dim, hidden_dim),
nn.Linear(input_dim, hidden_dim, bias=bias),
find_activation(activation),
nn.Dropout(dropout),
nn.Linear(hidden_dim, output_dim),
nn.Linear(hidden_dim, output_dim, bias=bias),
]
elif isinstance(hidden_dim, Sequence):
layers = [nn.Linear(input_dim, hidden_dim[0]), find_activation(activation)]
layers = [
nn.Linear(input_dim, hidden_dim[0], bias=bias),
find_activation(activation),
]
if len(hidden_dim) != 1:
for h_in, h_out in zip(hidden_dim[0:-1], hidden_dim[1:]):
layers.append(nn.Linear(h_in, h_out))
layers.append(nn.Linear(h_in, h_out, bias=bias))
layers.append(find_activation(activation))
layers.append(nn.Dropout(dropout))
layers.append(nn.Linear(hidden_dim[-1], output_dim))
layers.append(nn.Linear(hidden_dim[-1], output_dim, bias=bias))
else:
raise TypeError(
f"{hidden_dim=} must be an integer, a list of integers, or None."
Expand Down Expand Up @@ -109,22 +115,29 @@ def __init__(
input_dim: int,
output_dim: int,
hidden_dim: int | list[int] | None = None,
dropout=0,
activation="silu",
norm="batch",
dropout: float = 0,
activation: str = "silu",
norm: str = "batch",
bias: bool = True,
) -> None:
"""Initialize a gated MLP.
Args:
input_dim (int): the input dimension
output_dim (int): the output dimension
hidden_dim (list[int] | int]): a list of integers or a single integer representing
the number of hidden units in each layer of the MLP. Default = None
dropout (float): the dropout rate before each linear layer. Default: 0
activation (str, optional): The name of the activation function to use in the gated
MLP. Must be one of "relu", "silu", "tanh", or "gelu". Default = "silu".
norm (str, optional): The name of the normalization layer to use on the updated
atom features. Must be one of "batch", "layer", or None. Default = "batch".
hidden_dim (list[int] | int]): a list of integers or a single integer
representing the number of hidden units in each layer of the MLP.
Default = None
dropout (float): the dropout rate before each linear layer.
Default: 0
activation (str, optional): The name of the activation function to use in
the gated MLP. Must be one of "relu", "silu", "tanh", or "gelu".
Default = "silu"
norm (str, optional): The name of the normalization layer to use on the
updated atom features. Must be one of "batch", "layer", or None.
Default = "batch"
bias (bool): whether to use bias in each Linear layers.
Default = True
"""
super().__init__()
self.mlp_core = MLP(
Expand All @@ -133,13 +146,15 @@ def __init__(
hidden_dim=hidden_dim,
dropout=dropout,
activation=activation,
bias=bias,
)
self.mlp_gate = MLP(
input_dim=input_dim,
output_dim=output_dim,
hidden_dim=hidden_dim,
dropout=dropout,
activation=activation,
bias=bias,
)
self.activation = find_activation(activation)
self.sigmoid = nn.Sigmoid()
Expand Down
19 changes: 16 additions & 3 deletions chgnet/model/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def __init__(
activation: str = "silu",
norm: str | None = None,
use_mlp_out: bool = True,
mlp_out_bias: bool = False,
resnet: bool = True,
gMLP_norm: str | None = None,
) -> None:
Expand All @@ -46,6 +47,8 @@ def __init__(
use_mlp_out (bool, optional): Whether to apply an MLP output layer to the
updated atom features.
Default = True
mlp_out_bias (bool): whether to use bias in the output MLP Linear layer.
Default = False
resnet (bool, optional): Whether to apply a residual connection to the
updated atom features.
Default = True
Expand All @@ -67,7 +70,10 @@ def __init__(
)
if self.use_mlp_out:
self.mlp_out = MLP(
input_dim=atom_fea_dim, output_dim=atom_fea_dim, hidden_dim=0
input_dim=atom_fea_dim,
output_dim=atom_fea_dim,
hidden_dim=0,
bias=mlp_out_bias,
)
self.atom_norm = find_normalization(name=norm, dim=atom_fea_dim)

Expand Down Expand Up @@ -143,6 +149,7 @@ def __init__(
activation: str = "silu",
norm: str | None = None,
use_mlp_out: bool = True,
mlp_out_bias: bool = False,
resnet=True,
gMLP_norm: str | None = None,
) -> None:
Expand All @@ -166,6 +173,8 @@ def __init__(
use_mlp_out (bool, optional): Whether to apply an MLP output layer to the
updated atom features.
Default = True
mlp_out_bias (bool): whether to use bias in the output MLP Linear layer.
Default = False
resnet (bool, optional): Whether to apply a residual connection to the
updated atom features.
Default = True
Expand All @@ -187,7 +196,10 @@ def __init__(
)
if self.use_mlp_out:
self.mlp_out = MLP(
input_dim=bond_fea_dim, output_dim=bond_fea_dim, hidden_dim=0
input_dim=bond_fea_dim,
output_dim=bond_fea_dim,
hidden_dim=0,
bias=mlp_out_bias,
)
self.bond_norm = find_normalization(name=norm, dim=bond_fea_dim)

Expand Down Expand Up @@ -238,10 +250,11 @@ def forward(
new_bond_feas = aggregate(
bond_update, bond_graph[:, 1], average=False, num_owner=len(bond_feas)
)

print("before mlp", new_bond_feas)
# New bond features
if self.use_mlp_out:
new_bond_feas = self.mlp_out(new_bond_feas)
print("after", new_bond_feas)
if self.resnet:
new_bond_feas += bond_feas

Expand Down
10 changes: 7 additions & 3 deletions chgnet/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ def __init__(
# Define convolutional layers
conv_norm = kwargs.pop("conv_norm", None)
gMLP_norm = kwargs.pop("gMLP_norm", None)
mlp_out_bias = kwargs.pop("mlp_out_bias", False)
atom_graph_layers = [
AtomConv(
atom_fea_dim=atom_fea_dim,
Expand All @@ -211,6 +212,7 @@ def __init__(
norm=conv_norm,
gMLP_norm=gMLP_norm,
use_mlp_out=True,
mlp_out_bias=mlp_out_bias,
resnet=True,
)
for _ in range(n_conv)
Expand All @@ -229,6 +231,7 @@ def __init__(
norm=conv_norm,
gMLP_norm=gMLP_norm,
use_mlp_out=True,
mlp_out_bias=mlp_out_bias,
resnet=True,
)
for _ in range(n_conv - 1)
Expand Down Expand Up @@ -636,8 +639,8 @@ def todict(self):
@classmethod
def from_dict(cls, dict, **kwargs):
"""Build a CHGNet from a saved dictionary."""
chgnet = CHGNet(**dict["model_args"])
chgnet.load_state_dict(dict["state_dict"], **kwargs)
chgnet = CHGNet(**dict["model_args"], **kwargs)
chgnet.load_state_dict(dict["state_dict"])
return chgnet

@classmethod
Expand All @@ -652,7 +655,8 @@ def load(cls, model_name="MPtrj-efsm"):
current_dir = os.path.dirname(os.path.abspath(__file__))
if model_name == "MPtrj-efsm":
return cls.from_file(
os.path.join(current_dir, "../pretrained/e30f77s348m32.pth.tar")
os.path.join(current_dir, "../pretrained/e30f77s348m32.pth.tar"),
mlp_out_bias=True,
)
raise ValueError(f"Unknown {model_name=}")

Expand Down

0 comments on commit 1a558ea

Please sign in to comment.