Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use class-resolver for readouts and enable choosing MaxReadout #68

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion torchdrug/models/chebnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class ChebyshevConvolutionalNetwork(nn.Module, core.Configurable):
batch_norm (bool, optional): apply batch normalization or not
activation (str or function, optional): activation function
concat_hidden (bool, optional): concat hidden representations from all layers as output
readout (str, optional): readout function. Available functions are ``sum`` and ``mean``.
readout: readout function. Available functions are ``sum``, ``mean``, and ``max``.
"""

def __init__(self, input_dim, hidden_dims, edge_input_dim=None, k=1, short_cut=False, batch_norm=False,
Expand All @@ -49,6 +49,8 @@ def __init__(self, input_dim, hidden_dims, edge_input_dim=None, k=1, short_cut=F
self.readout = layers.SumReadout()
elif readout == "mean":
self.readout = layers.MeanReadout()
elif readout == "max":
self.readout = layers.MaxReadout()
else:
raise ValueError("Unknown readout `%s`" % readout)

Expand Down
6 changes: 4 additions & 2 deletions torchdrug/models/gat.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class GraphAttentionNetwork(nn.Module, core.Configurable):
batch_norm (bool, optional): apply batch normalization or not
activation (str or function, optional): activation function
concat_hidden (bool, optional): concat hidden representations from all layers as output
readout (str, optional): readout function. Available functions are ``sum`` and ``mean``.
readout: readout function. Available functions are ``sum``, ``mean``, and ``max``.
"""

def __init__(self, input_dim, hidden_dims, edge_input_dim=None, num_head=1, negative_slope=0.2, short_cut=False,
Expand All @@ -49,6 +49,8 @@ def __init__(self, input_dim, hidden_dims, edge_input_dim=None, num_head=1, nega
self.readout = layers.SumReadout()
elif readout == "mean":
self.readout = layers.MeanReadout()
elif readout == "max":
self.readout = layers.MaxReadout()
else:
raise ValueError("Unknown readout `%s`" % readout)

Expand Down Expand Up @@ -85,4 +87,4 @@ def forward(self, graph, input, all_loss=None, metric=None):
return {
"graph_feature": graph_feature,
"node_feature": node_feature
}
}
10 changes: 7 additions & 3 deletions torchdrug/models/gcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class GraphConvolutionalNetwork(nn.Module, core.Configurable):
batch_norm (bool, optional): apply batch normalization or not
activation (str or function, optional): activation function
concat_hidden (bool, optional): concat hidden representations from all layers as output
readout (str, optional): readout function. Available functions are ``sum`` and ``mean``.
readout: readout function. Available functions are ``sum``, ``mean``, and ``max``.
"""

def __init__(self, input_dim, hidden_dims, edge_input_dim=None, short_cut=False, batch_norm=False,
Expand All @@ -46,6 +46,8 @@ def __init__(self, input_dim, hidden_dims, edge_input_dim=None, short_cut=False,
self.readout = layers.SumReadout()
elif readout == "mean":
self.readout = layers.MeanReadout()
elif readout == "max":
self.readout = layers.MaxReadout()
else:
raise ValueError("Unknown readout `%s`" % readout)

Expand Down Expand Up @@ -102,7 +104,7 @@ class RelationalGraphConvolutionalNetwork(nn.Module, core.Configurable):
batch_norm (bool, optional): apply batch normalization or not
activation (str or function, optional): activation function
concat_hidden (bool, optional): concat hidden representations from all layers as output
readout (str, optional): readout function. Available functions are ``sum`` and ``mean``.
readout: readout function. Available functions are ``sum``, ``mean``, and ``max``.
"""

def __init__(self, input_dim, hidden_dims, num_relation, edge_input_dim=None, short_cut=False, batch_norm=False,
Expand All @@ -127,6 +129,8 @@ def __init__(self, input_dim, hidden_dims, num_relation, edge_input_dim=None, sh
self.readout = layers.SumReadout()
elif readout == "mean":
self.readout = layers.MeanReadout()
elif readout == "max":
self.readout = layers.MaxReadout()
else:
raise ValueError("Unknown readout `%s`" % readout)

Expand Down Expand Up @@ -165,4 +169,4 @@ def forward(self, graph, input, all_loss=None, metric=None):
return {
"graph_feature": graph_feature,
"node_feature": node_feature
}
}
9 changes: 5 additions & 4 deletions torchdrug/models/gin.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,11 @@ class GraphIsomorphismNetwork(nn.Module, core.Configurable):
batch_norm (bool, optional): apply batch normalization or not
activation (str or function, optional): activation function
concat_hidden (bool, optional): concat hidden representations from all layers as output
readout (str, optional): readout function. Available functions are ``sum`` and ``mean``.
readout: readout function. Available functions are ``sum``, ``mean``, and ``max``.
"""

def __init__(self, input_dim=None, hidden_dims=None, edge_input_dim=None, num_mlp_layer=2, eps=0, learn_eps=False,
short_cut=False, batch_norm=False, activation="relu", concat_hidden=False,
readout="sum"):
short_cut=False, batch_norm=False, activation="relu", concat_hidden=False, readout="sum"):
super(GraphIsomorphismNetwork, self).__init__()

if not isinstance(hidden_dims, Sequence):
Expand All @@ -52,6 +51,8 @@ def __init__(self, input_dim=None, hidden_dims=None, edge_input_dim=None, num_ml
self.readout = layers.SumReadout()
elif readout == "mean":
self.readout = layers.MeanReadout()
elif readout == "max":
self.readout = layers.MaxReadout()
else:
raise ValueError("Unknown readout `%s`" % readout)

Expand Down Expand Up @@ -88,4 +89,4 @@ def forward(self, graph, input, all_loss=None, metric=None):
return {
"graph_feature": graph_feature,
"node_feature": node_feature
}
}
4 changes: 3 additions & 1 deletion torchdrug/models/neuralfp.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class NeuralFingerprint(nn.Module, core.Configurable):
batch_norm (bool, optional): apply batch normalization or not
activation (str or function, optional): activation function
concat_hidden (bool, optional): concat hidden representations from all layers as output
readout (str, optional): readout function. Available functions are ``sum`` and ``mean``.
readout: readout function. Available functions are ``sum``, ``mean``, and ``max``.
"""

def __init__(self, input_dim, output_dim, hidden_dims, edge_input_dim=None, short_cut=False, batch_norm=False,
Expand All @@ -51,6 +51,8 @@ def __init__(self, input_dim, output_dim, hidden_dims, edge_input_dim=None, shor
self.readout = layers.SumReadout()
elif readout == "mean":
self.readout = layers.MeanReadout()
elif readout == "max":
self.readout = layers.MaxReadout()
else:
raise ValueError("Unknown readout `%s`" % readout)

Expand Down
12 changes: 10 additions & 2 deletions torchdrug/models/schnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,11 @@ class SchNet(nn.Module, core.Configurable):
batch_norm (bool, optional): apply batch normalization or not
activation (str or function, optional): activation function
concat_hidden (bool, optional): concat hidden representations from all layers as output
readout: readout function. Available functions are ``sum``, ``mean``, and ``max``.
"""

def __init__(self, input_dim, hidden_dims, edge_input_dim=None, cutoff=5, num_gaussian=100, short_cut=True,
batch_norm=False, activation="shifted_softplus", concat_hidden=False):
batch_norm=False, activation="shifted_softplus", concat_hidden=False, readout="sum"):
super(SchNet, self).__init__()

if not isinstance(hidden_dims, Sequence):
Expand All @@ -44,7 +45,14 @@ def __init__(self, input_dim, hidden_dims, edge_input_dim=None, cutoff=5, num_ga
self.layers.append(layers.ContinuousFilterConv(self.dims[i], self.dims[i + 1], edge_input_dim, None, cutoff,
num_gaussian, batch_norm, activation))

self.readout = layers.SumReadout()
if readout == "sum":
self.readout = layers.SumReadout()
elif readout == "mean":
self.readout = layers.MeanReadout()
elif readout == "max":
self.readout = layers.MaxReadout()
else:
raise ValueError("Unknown readout `%s`" % readout)

def forward(self, graph, input, all_loss=None, metric=None):
"""
Expand Down
4 changes: 4 additions & 0 deletions torchdrug/tasks/pretrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ class ContextPrediction(tasks.Task, core.Configurable):
r2 (int, optional): outer radius for context graphs
readout (str, optional): readout function over context anchor nodes.
    Available functions are ``sum``, ``mean``, and ``max``.
num_negative (int, optional): number of negative samples per positive sample
"""

def __init__(self, model, context_model=None, k=5, r1=4, r2=7, readout="mean", num_negative=1):
Expand All @@ -184,10 +185,13 @@ def __init__(self, model, context_model=None, k=5, r1=4, r2=7, readout="mean", n
self.context_model = copy.deepcopy(model)
else:
self.context_model = context_model

if readout == "sum":
self.readout = layers.SumReadout()
elif readout == "mean":
self.readout = layers.MeanReadout()
elif readout == "max":
self.readout = layers.MaxReadout()
else:
raise ValueError("Unknown readout `%s`" % readout)

Expand Down