diff --git a/torchdrug/models/chebnet.py b/torchdrug/models/chebnet.py index cb1793d..86d2aef 100644 --- a/torchdrug/models/chebnet.py +++ b/torchdrug/models/chebnet.py @@ -25,7 +25,7 @@ class ChebyshevConvolutionalNetwork(nn.Module, core.Configurable): batch_norm (bool, optional): apply batch normalization or not activation (str or function, optional): activation function concat_hidden (bool, optional): concat hidden representations from all layers as output - readout (str, optional): readout function. Available functions are ``sum`` and ``mean``. + readout (str, optional): readout function. Available functions are ``sum``, ``mean``, and ``max``. """ def __init__(self, input_dim, hidden_dims, edge_input_dim=None, k=1, short_cut=False, batch_norm=False, @@ -49,6 +49,8 @@ def __init__(self, input_dim, hidden_dims, edge_input_dim=None, k=1, short_cut=F self.readout = layers.SumReadout() elif readout == "mean": self.readout = layers.MeanReadout() + elif readout == "max": + self.readout = layers.MaxReadout() else: raise ValueError("Unknown readout `%s`" % readout) diff --git a/torchdrug/models/gat.py b/torchdrug/models/gat.py index 5aca144..c22e848 100644 --- a/torchdrug/models/gat.py +++ b/torchdrug/models/gat.py @@ -25,7 +25,7 @@ class GraphAttentionNetwork(nn.Module, core.Configurable): batch_norm (bool, optional): apply batch normalization or not activation (str or function, optional): activation function concat_hidden (bool, optional): concat hidden representations from all layers as output - readout (str, optional): readout function. Available functions are ``sum`` and ``mean``. + readout (str, optional): readout function. Available functions are ``sum``, ``mean``, and ``max``. 
""" def __init__(self, input_dim, hidden_dims, edge_input_dim=None, num_head=1, negative_slope=0.2, short_cut=False, @@ -49,6 +49,8 @@ def __init__(self, input_dim, hidden_dims, edge_input_dim=None, num_head=1, nega self.readout = layers.SumReadout() elif readout == "mean": self.readout = layers.MeanReadout() + elif readout == "max": + self.readout = layers.MaxReadout() else: raise ValueError("Unknown readout `%s`" % readout) @@ -85,4 +87,4 @@ def forward(self, graph, input, all_loss=None, metric=None): return { "graph_feature": graph_feature, "node_feature": node_feature - } \ No newline at end of file + } diff --git a/torchdrug/models/gcn.py b/torchdrug/models/gcn.py index a4fd3b9..352c39b 100644 --- a/torchdrug/models/gcn.py +++ b/torchdrug/models/gcn.py @@ -23,7 +23,7 @@ class GraphConvolutionalNetwork(nn.Module, core.Configurable): batch_norm (bool, optional): apply batch normalization or not activation (str or function, optional): activation function concat_hidden (bool, optional): concat hidden representations from all layers as output - readout (str, optional): readout function. Available functions are ``sum`` and ``mean``. + readout: readout function. Available functions are ``sum``, ``mean``, and ``max``. """ def __init__(self, input_dim, hidden_dims, edge_input_dim=None, short_cut=False, batch_norm=False, @@ -46,6 +46,8 @@ def __init__(self, input_dim, hidden_dims, edge_input_dim=None, short_cut=False, self.readout = layers.SumReadout() elif readout == "mean": self.readout = layers.MeanReadout() + elif readout == "max": + self.readout = layers.MaxReadout() else: raise ValueError("Unknown readout `%s`" % readout) @@ -102,7 +104,7 @@ class RelationalGraphConvolutionalNetwork(nn.Module, core.Configurable): batch_norm (bool, optional): apply batch normalization or not activation (str or function, optional): activation function concat_hidden (bool, optional): concat hidden representations from all layers as output - readout (str, optional): readout function. 
Available functions are ``sum`` and ``mean``. + readout (str, optional): readout function. Available functions are ``sum``, ``mean``, and ``max``. """ def __init__(self, input_dim, hidden_dims, num_relation, edge_input_dim=None, short_cut=False, batch_norm=False, @@ -127,6 +129,8 @@ def __init__(self, input_dim, hidden_dims, num_relation, edge_input_dim=None, sh self.readout = layers.SumReadout() elif readout == "mean": self.readout = layers.MeanReadout() + elif readout == "max": + self.readout = layers.MaxReadout() else: raise ValueError("Unknown readout `%s`" % readout) @@ -165,4 +169,4 @@ def forward(self, graph, input, all_loss=None, metric=None): return { "graph_feature": graph_feature, "node_feature": node_feature - } \ No newline at end of file + } diff --git a/torchdrug/models/gin.py b/torchdrug/models/gin.py index 70e217d..f0d99de 100644 --- a/torchdrug/models/gin.py +++ b/torchdrug/models/gin.py @@ -26,12 +26,11 @@ class GraphIsomorphismNetwork(nn.Module, core.Configurable): batch_norm (bool, optional): apply batch normalization or not activation (str or function, optional): activation function concat_hidden (bool, optional): concat hidden representations from all layers as output - readout (str, optional): readout function. Available functions are ``sum`` and ``mean``. + readout (str, optional): readout function. Available functions are ``sum``, ``mean``, and ``max``. 
""" def __init__(self, input_dim=None, hidden_dims=None, edge_input_dim=None, num_mlp_layer=2, eps=0, learn_eps=False, - short_cut=False, batch_norm=False, activation="relu", concat_hidden=False, - readout="sum"): + short_cut=False, batch_norm=False, activation="relu", concat_hidden=False, readout="sum"): super(GraphIsomorphismNetwork, self).__init__() if not isinstance(hidden_dims, Sequence): @@ -52,6 +51,8 @@ def __init__(self, input_dim=None, hidden_dims=None, edge_input_dim=None, num_ml self.readout = layers.SumReadout() elif readout == "mean": self.readout = layers.MeanReadout() + elif readout == "max": + self.readout = layers.MaxReadout() else: raise ValueError("Unknown readout `%s`" % readout) @@ -88,4 +89,4 @@ def forward(self, graph, input, all_loss=None, metric=None): return { "graph_feature": graph_feature, "node_feature": node_feature - } \ No newline at end of file + } diff --git a/torchdrug/models/neuralfp.py b/torchdrug/models/neuralfp.py index ec47c2c..b4e5979 100644 --- a/torchdrug/models/neuralfp.py +++ b/torchdrug/models/neuralfp.py @@ -25,7 +25,7 @@ class NeuralFingerprint(nn.Module, core.Configurable): batch_norm (bool, optional): apply batch normalization or not activation (str or function, optional): activation function concat_hidden (bool, optional): concat hidden representations from all layers as output - readout (str, optional): readout function. Available functions are ``sum`` and ``mean``. + readout: readout function. Available functions are ``sum``, ``mean``, and ``max``. 
""" def __init__(self, input_dim, output_dim, hidden_dims, edge_input_dim=None, short_cut=False, batch_norm=False, @@ -51,6 +51,8 @@ def __init__(self, input_dim, output_dim, hidden_dims, edge_input_dim=None, shor self.readout = layers.SumReadout() elif readout == "mean": self.readout = layers.MeanReadout() + elif readout == "max": + self.readout = layers.MaxReadout() else: raise ValueError("Unknown readout `%s`" % readout) diff --git a/torchdrug/models/schnet.py b/torchdrug/models/schnet.py index 0bfb2cf..1a7cf09 100644 --- a/torchdrug/models/schnet.py +++ b/torchdrug/models/schnet.py @@ -25,10 +25,11 @@ class SchNet(nn.Module, core.Configurable): batch_norm (bool, optional): apply batch normalization or not activation (str or function, optional): activation function concat_hidden (bool, optional): concat hidden representations from all layers as output + readout: readout function. Available functions are ``sum``, ``mean``, and ``max``. """ def __init__(self, input_dim, hidden_dims, edge_input_dim=None, cutoff=5, num_gaussian=100, short_cut=True, - batch_norm=False, activation="shifted_softplus", concat_hidden=False): + batch_norm=False, activation="shifted_softplus", concat_hidden=False, readout="sum"): super(SchNet, self).__init__() if not isinstance(hidden_dims, Sequence): @@ -44,7 +45,14 @@ def __init__(self, input_dim, hidden_dims, edge_input_dim=None, cutoff=5, num_ga self.layers.append(layers.ContinuousFilterConv(self.dims[i], self.dims[i + 1], edge_input_dim, None, cutoff, num_gaussian, batch_norm, activation)) - self.readout = layers.SumReadout() + if readout == "sum": + self.readout = layers.SumReadout() + elif readout == "mean": + self.readout = layers.MeanReadout() + elif readout == "max": + self.readout = layers.MaxReadout() + else: + raise ValueError("Unknown readout `%s`" % readout) def forward(self, graph, input, all_loss=None, metric=None): """ diff --git a/torchdrug/tasks/pretrain.py b/torchdrug/tasks/pretrain.py index 5e8deec..a26bc5a 100644 --- 
a/torchdrug/tasks/pretrain.py +++ b/torchdrug/tasks/pretrain.py @@ -169,6 +169,6 @@ class ContextPrediction(tasks.Task, core.Configurable): r2 (int, optional): outer radius for context graphs - readout (nn.Module, optional): readout function over context anchor nodes + readout (str, optional): readout function over context anchor nodes. Available functions are ``sum``, ``mean``, and ``max``. num_negative (int, optional): number of negative samples per positive sample """ def __init__(self, model, context_model=None, k=5, r1=4, r2=7, readout="mean", num_negative=1): @@ -184,10 +184,13 @@ def __init__(self, model, context_model=None, k=5, r1=4, r2=7, readout="mean", n self.context_model = copy.deepcopy(model) else: self.context_model = context_model + if readout == "sum": self.readout = layers.SumReadout() elif readout == "mean": self.readout = layers.MeanReadout() + elif readout == "max": + self.readout = layers.MaxReadout() else: raise ValueError("Unknown readout `%s`" % readout)