diff --git a/chgnet/model/functions.py b/chgnet/model/functions.py
index 78fa180d..19345101 100644
--- a/chgnet/model/functions.py
+++ b/chgnet/model/functions.py
@@ -49,6 +49,7 @@ def __init__(
         hidden_dim: int | Sequence[int] | None = (64, 64),
         dropout: float = 0,
         activation: str = "silu",
+        bias: bool = True,
     ) -> None:
         """Initialize the MLP.
 
@@ -61,26 +62,31 @@ def __init__(
             dropout (float): the dropout rate before each linear layer. Default: 0
             activation (str, optional): The name of the activation function to use
                 in the gated MLP. Must be one of "relu", "silu", "tanh", or "gelu".
-                Default = "silu".
+                Default = "silu"
+            bias (bool): whether to use bias in each Linear layer.
+                Default = True
         """
         super().__init__()
         if hidden_dim in (None, 0):
-            layers = [nn.Dropout(dropout), nn.Linear(input_dim, output_dim)]
+            layers = [nn.Dropout(dropout), nn.Linear(input_dim, output_dim, bias=bias)]
         elif isinstance(hidden_dim, int):
             layers = [
-                nn.Linear(input_dim, hidden_dim),
+                nn.Linear(input_dim, hidden_dim, bias=bias),
                 find_activation(activation),
                 nn.Dropout(dropout),
-                nn.Linear(hidden_dim, output_dim),
+                nn.Linear(hidden_dim, output_dim, bias=bias),
             ]
         elif isinstance(hidden_dim, Sequence):
-            layers = [nn.Linear(input_dim, hidden_dim[0]), find_activation(activation)]
+            layers = [
+                nn.Linear(input_dim, hidden_dim[0], bias=bias),
+                find_activation(activation),
+            ]
             if len(hidden_dim) != 1:
                 for h_in, h_out in zip(hidden_dim[0:-1], hidden_dim[1:]):
-                    layers.append(nn.Linear(h_in, h_out))
+                    layers.append(nn.Linear(h_in, h_out, bias=bias))
                     layers.append(find_activation(activation))
             layers.append(nn.Dropout(dropout))
-            layers.append(nn.Linear(hidden_dim[-1], output_dim))
+            layers.append(nn.Linear(hidden_dim[-1], output_dim, bias=bias))
         else:
             raise TypeError(
                 f"{hidden_dim=} must be an integer, a list of integers, or None."
@@ -109,22 +115,29 @@ def __init__(
         input_dim: int,
         output_dim: int,
         hidden_dim: int | list[int] | None = None,
-        dropout=0,
-        activation="silu",
-        norm="batch",
+        dropout: float = 0,
+        activation: str = "silu",
+        norm: str = "batch",
+        bias: bool = True,
     ) -> None:
         """Initialize a gated MLP.
 
         Args:
             input_dim (int): the input dimension
             output_dim (int): the output dimension
-            hidden_dim (list[int] | int]): a list of integers or a single integer representing
-                the number of hidden units in each layer of the MLP. Default = None
-            dropout (float): the dropout rate before each linear layer. Default: 0
-            activation (str, optional): The name of the activation function to use in the gated
-                MLP. Must be one of "relu", "silu", "tanh", or "gelu". Default = "silu".
-            norm (str, optional): The name of the normalization layer to use on the updated
-                atom features. Must be one of "batch", "layer", or None. Default = "batch".
+            hidden_dim (list[int] | int): a list of integers or a single integer
+                representing the number of hidden units in each layer of the MLP.
+                Default = None
+            dropout (float): the dropout rate before each linear layer.
+                Default: 0
+            activation (str, optional): The name of the activation function to use in
+                the gated MLP. Must be one of "relu", "silu", "tanh", or "gelu".
+                Default = "silu"
+            norm (str, optional): The name of the normalization layer to use on the
+                updated atom features. Must be one of "batch", "layer", or None.
+                Default = "batch"
+            bias (bool): whether to use bias in each Linear layer.
+                Default = True
         """
         super().__init__()
         self.mlp_core = MLP(
@@ -133,6 +146,7 @@ def __init__(
             hidden_dim=hidden_dim,
             dropout=dropout,
             activation=activation,
+            bias=bias,
         )
         self.mlp_gate = MLP(
             input_dim=input_dim,
@@ -140,6 +154,7 @@ def __init__(
             hidden_dim=hidden_dim,
             dropout=dropout,
             activation=activation,
+            bias=bias,
         )
         self.activation = find_activation(activation)
         self.sigmoid = nn.Sigmoid()
diff --git a/chgnet/model/layers.py b/chgnet/model/layers.py
index 1b5c859a..6a9f38cf 100644
--- a/chgnet/model/layers.py
+++ b/chgnet/model/layers.py
@@ -24,6 +24,7 @@ def __init__(
         activation: str = "silu",
         norm: str | None = None,
         use_mlp_out: bool = True,
+        mlp_out_bias: bool = False,
         resnet: bool = True,
         gMLP_norm: str | None = None,
     ) -> None:
@@ -46,6 +47,8 @@ def __init__(
             use_mlp_out (bool, optional): Whether to apply an MLP output layer to the
                 updated atom features.
                 Default = True
+            mlp_out_bias (bool): whether to use bias in the output MLP Linear layer.
+                Default = False
             resnet (bool, optional): Whether to apply a residual connection to the
                 updated atom features.
                 Default = True
@@ -67,7 +70,10 @@ def __init__(
         )
         if self.use_mlp_out:
             self.mlp_out = MLP(
-                input_dim=atom_fea_dim, output_dim=atom_fea_dim, hidden_dim=0
+                input_dim=atom_fea_dim,
+                output_dim=atom_fea_dim,
+                hidden_dim=0,
+                bias=mlp_out_bias,
             )
         self.atom_norm = find_normalization(name=norm, dim=atom_fea_dim)
 
@@ -143,6 +149,7 @@ def __init__(
         activation: str = "silu",
         norm: str | None = None,
         use_mlp_out: bool = True,
+        mlp_out_bias: bool = False,
         resnet=True,
         gMLP_norm: str | None = None,
     ) -> None:
@@ -166,6 +173,8 @@ def __init__(
             use_mlp_out (bool, optional): Whether to apply an MLP output layer to the
                 updated atom features.
                 Default = True
+            mlp_out_bias (bool): whether to use bias in the output MLP Linear layer.
+                Default = False
             resnet (bool, optional): Whether to apply a residual connection to the
                 updated atom features.
                 Default = True
@@ -187,7 +196,10 @@ def __init__(
         )
         if self.use_mlp_out:
             self.mlp_out = MLP(
-                input_dim=bond_fea_dim, output_dim=bond_fea_dim, hidden_dim=0
+                input_dim=bond_fea_dim,
+                output_dim=bond_fea_dim,
+                hidden_dim=0,
+                bias=mlp_out_bias,
             )
         self.bond_norm = find_normalization(name=norm, dim=bond_fea_dim)
 
diff --git a/chgnet/model/model.py b/chgnet/model/model.py
index 383f2f00..6b0bd0e1 100644
--- a/chgnet/model/model.py
+++ b/chgnet/model/model.py
@@ -201,6 +201,7 @@ def __init__(
         # Define convolutional layers
         conv_norm = kwargs.pop("conv_norm", None)
         gMLP_norm = kwargs.pop("gMLP_norm", None)
+        mlp_out_bias = kwargs.pop("mlp_out_bias", False)
         atom_graph_layers = [
             AtomConv(
                 atom_fea_dim=atom_fea_dim,
@@ -211,6 +212,7 @@ def __init__(
                 norm=conv_norm,
                 gMLP_norm=gMLP_norm,
                 use_mlp_out=True,
+                mlp_out_bias=mlp_out_bias,
                 resnet=True,
             )
             for _ in range(n_conv)
@@ -229,6 +231,7 @@ def __init__(
                 norm=conv_norm,
                 gMLP_norm=gMLP_norm,
                 use_mlp_out=True,
+                mlp_out_bias=mlp_out_bias,
                 resnet=True,
             )
             for _ in range(n_conv - 1)
@@ -636,8 +639,8 @@ def todict(self):
     @classmethod
     def from_dict(cls, dict, **kwargs):
         """Build a CHGNet from a saved dictionary."""
-        chgnet = CHGNet(**dict["model_args"])
-        chgnet.load_state_dict(dict["state_dict"], **kwargs)
+        chgnet = CHGNet(**dict["model_args"], **kwargs)
+        chgnet.load_state_dict(dict["state_dict"])
         return chgnet
 
     @classmethod
@@ -652,7 +655,8 @@ def load(cls, model_name="MPtrj-efsm"):
         current_dir = os.path.dirname(os.path.abspath(__file__))
         if model_name == "MPtrj-efsm":
             return cls.from_file(
-                os.path.join(current_dir, "../pretrained/e30f77s348m32.pth.tar")
+                os.path.join(current_dir, "../pretrained/e30f77s348m32.pth.tar"),
+                mlp_out_bias=True,
            )
         raise ValueError(f"Unknown {model_name=}")
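
Usage sketch of the new flags: `MLP`/`GatedMLP` gain a `bias` argument, `CHGNet` pops `mlp_out_bias` from `**kwargs` and forwards it to the `AtomConv`/`BondConv` output MLPs, and `CHGNet.load()` enables it for the bundled MPtrj-efsm checkpoint. This is a minimal example assuming the patched `chgnet` package above is importable; the dimensions are illustrative.

```python
# Minimal sketch of the new bias flags; assumes the patched chgnet package
# above is importable. Dimensions are illustrative, not prescriptive.
import torch

from chgnet.model.functions import MLP
from chgnet.model.model import CHGNet

# bias=False propagates to every nn.Linear built inside the MLP.
mlp = MLP(input_dim=64, output_dim=64, hidden_dim=(64, 64), bias=False)
assert all(
    m.bias is None for m in mlp.modules() if isinstance(m, torch.nn.Linear)
)

# mlp_out_bias is popped from **kwargs in CHGNet.__init__ and handed to the
# AtomConv/BondConv output MLPs (single hidden_dim=0 Linear layers).
model = CHGNet(mlp_out_bias=True)

# The bundled MPtrj-efsm weights were trained with biased output MLPs, so
# CHGNet.load() now routes mlp_out_bias=True through from_file/from_dict.
pretrained = CHGNet.load(model_name="MPtrj-efsm")
```

The default stays `mlp_out_bias=False` for newly trained models; only the pretrained-checkpoint loader opts into the biased layout, which is why `from_dict` now forwards `**kwargs` to the `CHGNet` constructor rather than to `load_state_dict`.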