From ca856c8264e932b9619a1a80788f817d7158f693 Mon Sep 17 00:00:00 2001 From: Charles Tapley Hoyt Date: Tue, 1 Feb 2022 17:24:13 +0100 Subject: [PATCH 1/6] Enable max readouts --- torchdrug/models/gat.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torchdrug/models/gat.py b/torchdrug/models/gat.py index 5aca144f..e4da32a8 100644 --- a/torchdrug/models/gat.py +++ b/torchdrug/models/gat.py @@ -25,7 +25,7 @@ class GraphAttentionNetwork(nn.Module, core.Configurable): batch_norm (bool, optional): apply batch normalization or not activation (str or function, optional): activation function concat_hidden (bool, optional): concat hidden representations from all layers as output - readout (str, optional): readout function. Available functions are ``sum`` and ``mean``. + readout (str, optional): readout function. Available functions are ``sum``, ``mean``, and ``max``. """ def __init__(self, input_dim, hidden_dims, edge_input_dim=None, num_head=1, negative_slope=0.2, short_cut=False, @@ -49,6 +49,8 @@ def __init__(self, input_dim, hidden_dims, edge_input_dim=None, num_head=1, nega self.readout = layers.SumReadout() elif readout == "mean": self.readout = layers.MeanReadout() + elif readout == "max": + self.readout = layers.MaxReadout() else: raise ValueError("Unknown readout `%s`" % readout) @@ -85,4 +87,4 @@ def forward(self, graph, input, all_loss=None, metric=None): return { "graph_feature": graph_feature, "node_feature": node_feature - } \ No newline at end of file + } From 2c5d0f79301f05e516ef52d077a7572baff9411b Mon Sep 17 00:00:00 2001 From: Charles Tapley Hoyt Date: Tue, 1 Feb 2022 17:28:44 +0100 Subject: [PATCH 2/6] Update gcn.py --- torchdrug/models/gcn.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/torchdrug/models/gcn.py b/torchdrug/models/gcn.py index a4fd3b96..810119fa 100644 --- a/torchdrug/models/gcn.py +++ b/torchdrug/models/gcn.py @@ -23,7 +23,7 @@ class GraphConvolutionalNetwork(nn.Module, core.Configurable): batch_norm (bool, optional): apply batch normalization or not activation (str or function, optional): activation function concat_hidden (bool, optional): concat hidden representations from all layers as output - readout (str, optional): readout function. Available functions are ``sum`` and ``mean``. + readout (str, optional): readout function. Available functions are ``sum``, ``mean``, and ``max``. """ def __init__(self, input_dim, hidden_dims, edge_input_dim=None, short_cut=False, batch_norm=False, @@ -46,6 +46,8 @@ def __init__(self, input_dim, hidden_dims, edge_input_dim=None, short_cut=False, self.readout = layers.SumReadout() elif readout == "mean": self.readout = layers.MeanReadout() + elif readout == "max": + self.readout = layers.MaxReadout() else: raise ValueError("Unknown readout `%s`" % readout) @@ -102,7 +104,7 @@ class RelationalGraphConvolutionalNetwork(nn.Module, core.Configurable): batch_norm (bool, optional): apply batch normalization or not activation (str or function, optional): activation function concat_hidden (bool, optional): concat hidden representations from all layers as output - readout (str, optional): readout function. Available functions are ``sum`` and ``mean``. + readout (str, optional): readout function. Available functions are ``sum``, ``mean``, and ``max``. 
""" def __init__(self, input_dim, hidden_dims, num_relation, edge_input_dim=None, short_cut=False, batch_norm=False, @@ -127,6 +129,8 @@ def __init__(self, input_dim, hidden_dims, num_relation, edge_input_dim=None, sh self.readout = layers.SumReadout() elif readout == "mean": self.readout = layers.MeanReadout() + elif readout == "max": + self.readout = layers.MaxReadout() else: raise ValueError("Unknown readout `%s`" % readout) @@ -165,4 +169,4 @@ def forward(self, graph, input, all_loss=None, metric=None): return { "graph_feature": graph_feature, "node_feature": node_feature - } \ No newline at end of file + } From d6b544dfda91b0c504252d3bda21c04bd0b3cfa8 Mon Sep 17 00:00:00 2001 From: Charles Tapley Hoyt Date: Tue, 1 Feb 2022 17:32:40 +0100 Subject: [PATCH 3/6] Update gin.py --- torchdrug/models/gin.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torchdrug/models/gin.py b/torchdrug/models/gin.py index 70e217d2..ed95cb7d 100644 --- a/torchdrug/models/gin.py +++ b/torchdrug/models/gin.py @@ -26,7 +26,7 @@ class GraphIsomorphismNetwork(nn.Module, core.Configurable): batch_norm (bool, optional): apply batch normalization or not activation (str or function, optional): activation function concat_hidden (bool, optional): concat hidden representations from all layers as output - readout (str, optional): readout function. Available functions are ``sum`` and ``mean``. + readout (str, optional): readout function. Available functions are ``sum``, ``mean``, and ``max``. """ def __init__(self, input_dim=None, hidden_dims=None, edge_input_dim=None, num_mlp_layer=2, eps=0, learn_eps=False, @@ -52,6 +52,8 @@ def __init__(self, input_dim=None, hidden_dims=None, edge_input_dim=None, num_ml self.readout = layers.SumReadout() elif readout == "mean": self.readout = layers.MeanReadout() + elif readout == "max": + self.readout = layers.MaxReadout() else: raise ValueError("Unknown readout `%s`" % readout) @@ -88,4 +90,4 @@ def forward(self, graph, input, all_loss=None, metric=None): return { "graph_feature": graph_feature, "node_feature": node_feature - } \ No newline at end of file + } From 4d74e5d4a5261c89c56fb95dc064ec85ea53fe19 Mon Sep 17 00:00:00 2001 From: Charles Tapley Hoyt Date: Tue, 8 Feb 2022 15:30:36 +0100 Subject: [PATCH 4/6] Add class-resolver --- torchdrug/layers/__init__.py | 2 +- torchdrug/layers/readout.py | 17 ++++++++++++++--- torchdrug/models/gat.py | 13 +++---------- torchdrug/models/gcn.py | 15 +++++---------- torchdrug/models/gin.py | 9 +-------- torchdrug/models/neuralfp.py | 7 +------ torchdrug/models/schnet.py | 4 ++-- 7 files changed, 27 insertions(+), 40 deletions(-) diff --git a/torchdrug/layers/__init__.py b/torchdrug/layers/__init__.py index 22c3b60b..7d79f9be 100644 --- a/torchdrug/layers/__init__.py +++ b/torchdrug/layers/__init__.py @@ -3,7 +3,7 @@ from .conv import MessagePassingBase, GraphConv, GraphAttentionConv, RelationalGraphConv, GraphIsomorphismConv, \ NeuralFingerprintConv, ContinuousFilterConv, MessagePassing, ChebyshevConv from .pool import DiffPool, MinCutPool -from .readout import MeanReadout, SumReadout, MaxReadout, Softmax, Set2Set, Sort +from .readout import MeanReadout, SumReadout, MaxReadout, Softmax, Set2Set, Sort, readout_resolver, Readout from .flow import ConditionalFlow from .sampler import NodeSampler, EdgeSampler from . 
import distribution, functional diff --git a/torchdrug/layers/readout.py b/torchdrug/layers/readout.py index 3680d292..66ed64ea 100644 --- a/torchdrug/layers/readout.py +++ b/torchdrug/layers/readout.py @@ -1,9 +1,14 @@ import torch from torch import nn from torch_scatter import scatter_mean, scatter_add, scatter_max +from class_resolver import ClassResolver -class MeanReadout(nn.Module): +class Readout(nn.Module): + """A base class for readouts.""" + + +class MeanReadout(Readout): """Mean readout operator over graphs with variadic sizes.""" def forward(self, graph, input): @@ -21,7 +26,7 @@ def forward(self, graph, input): return output -class SumReadout(nn.Module): +class SumReadout(Readout): """Sum readout operator over graphs with variadic sizes.""" def forward(self, graph, input): @@ -39,7 +44,7 @@ def forward(self, graph, input): return output -class MaxReadout(nn.Module): +class MaxReadout(Readout): """Max readout operator over graphs with variadic sizes.""" def forward(self, graph, input): @@ -57,6 +62,12 @@ def forward(self, graph, input): return output +readout_resolver = ClassResolver.from_subclasses( + Readout, + default=SumReadout, +) + + class Softmax(nn.Module): """Softmax operator over graphs with variadic sizes.""" diff --git a/torchdrug/models/gat.py b/torchdrug/models/gat.py index e4da32a8..211f8d0b 100644 --- a/torchdrug/models/gat.py +++ b/torchdrug/models/gat.py @@ -25,11 +25,11 @@ class GraphAttentionNetwork(nn.Module, core.Configurable): batch_norm (bool, optional): apply batch normalization or not activation (str or function, optional): activation function concat_hidden (bool, optional): concat hidden representations from all layers as output - readout (str, optional): readout function. Available functions are ``sum``, ``mean``, and ``max``. + readout: readout function. Available functions are ``sum``, ``mean``, and ``max``. 
""" def __init__(self, input_dim, hidden_dims, edge_input_dim=None, num_head=1, negative_slope=0.2, short_cut=False, - batch_norm=False, activation="relu", concat_hidden=False, readout="sum"): + batch_norm=False, activation="relu", concat_hidden=False, readout: Hint[Readout] = "sum"): super(GraphAttentionNetwork, self).__init__() if not isinstance(hidden_dims, Sequence): @@ -45,14 +45,7 @@ def __init__(self, input_dim, hidden_dims, edge_input_dim=None, num_head=1, nega self.layers.append(layers.GraphAttentionConv(self.dims[i], self.dims[i + 1], edge_input_dim, num_head, negative_slope, batch_norm, activation)) - if readout == "sum": - self.readout = layers.SumReadout() - elif readout == "mean": - self.readout = layers.MeanReadout() - elif readout == "max": - self.readout = layers.MaxReadout() - else: - raise ValueError("Unknown readout `%s`" % readout) + self.readout = readout_resolver.make(readout) def forward(self, graph, input, all_loss=None, metric=None): """ diff --git a/torchdrug/models/gcn.py b/torchdrug/models/gcn.py index 810119fa..968f2b1f 100644 --- a/torchdrug/models/gcn.py +++ b/torchdrug/models/gcn.py @@ -1,9 +1,11 @@ from collections.abc import Sequence import torch +from class_resolver import Hint from torch import nn from torchdrug import core, layers +from torchdrug.layers import readout_resolver, Readout from torchdrug.core import Registry as R @@ -23,11 +25,11 @@ class GraphConvolutionalNetwork(nn.Module, core.Configurable): batch_norm (bool, optional): apply batch normalization or not activation (str or function, optional): activation function concat_hidden (bool, optional): concat hidden representations from all layers as output - readout (str, optional): readout function. Available functions are ``sum``, ``mean``, and ``max``. + readout: readout function. Available functions are ``sum``, ``mean``, and ``max``. 
""" def __init__(self, input_dim, hidden_dims, edge_input_dim=None, short_cut=False, batch_norm=False, - activation="relu", concat_hidden=False, readout="sum"): + activation="relu", concat_hidden=False, readout: Hint[Readout] = "sum"): super(GraphConvolutionalNetwork, self).__init__() if not isinstance(hidden_dims, Sequence): @@ -42,14 +44,7 @@ def __init__(self, input_dim, hidden_dims, edge_input_dim=None, short_cut=False, for i in range(len(self.dims) - 1): self.layers.append(layers.GraphConv(self.dims[i], self.dims[i + 1], edge_input_dim, batch_norm, activation)) - if readout == "sum": - self.readout = layers.SumReadout() - elif readout == "mean": - self.readout = layers.MeanReadout() - elif readout == "max": - self.readout = layers.MaxReadout() - else: - raise ValueError("Unknown readout `%s`" % readout) + self.readout = readout_resolver.make(readout) def forward(self, graph, input, all_loss=None, metric=None): """ diff --git a/torchdrug/models/gin.py b/torchdrug/models/gin.py index ed95cb7d..4b945cf0 100644 --- a/torchdrug/models/gin.py +++ b/torchdrug/models/gin.py @@ -48,14 +48,7 @@ def __init__(self, input_dim=None, hidden_dims=None, edge_input_dim=None, num_ml self.layers.append(layers.GraphIsomorphismConv(self.dims[i], self.dims[i + 1], edge_input_dim, layer_hidden_dims, eps, learn_eps, batch_norm, activation)) - if readout == "sum": - self.readout = layers.SumReadout() - elif readout == "mean": - self.readout = layers.MeanReadout() - elif readout == "max": - self.readout = layers.MaxReadout() - else: - raise ValueError("Unknown readout `%s`" % readout) + self.readout = readout_resolver.make(readout) def forward(self, graph, input, all_loss=None, metric=None): """ diff --git a/torchdrug/models/neuralfp.py b/torchdrug/models/neuralfp.py index ec47c2c5..4d9a1ab6 100644 --- a/torchdrug/models/neuralfp.py +++ b/torchdrug/models/neuralfp.py @@ -47,12 +47,7 @@ def __init__(self, input_dim, output_dim, hidden_dims, edge_input_dim=None, shor batch_norm, activation)) self.linears.append(nn.Linear(self.dims[i + 1], output_dim)) - if readout == "sum": - self.readout = layers.SumReadout() - elif readout == "mean": - self.readout = layers.MeanReadout() - else: - raise ValueError("Unknown readout `%s`" % readout) + self.readout = readout_resolver.make(readout) def forward(self, graph, input, all_loss=None, metric=None): """ diff --git a/torchdrug/models/schnet.py b/torchdrug/models/schnet.py index 0bfb2cf9..3644e2ec 100644 --- a/torchdrug/models/schnet.py +++ b/torchdrug/models/schnet.py @@ -28,7 +28,7 @@ class SchNet(nn.Module, core.Configurable): """ def __init__(self, input_dim, hidden_dims, edge_input_dim=None, cutoff=5, num_gaussian=100, short_cut=True, - batch_norm=False, activation="shifted_softplus", concat_hidden=False): + batch_norm=False, activation="shifted_softplus", concat_hidden=False, readout: Hint[Readout] = "sum"): super(SchNet, self).__init__() if not isinstance(hidden_dims, Sequence): @@ -44,7 +44,7 @@ def __init__(self, input_dim, hidden_dims, edge_input_dim=None, cutoff=5, num_ga self.layers.append(layers.ContinuousFilterConv(self.dims[i], self.dims[i + 1], edge_input_dim, None, cutoff, num_gaussian, batch_norm, activation)) - self.readout = layers.SumReadout() + self.readout = readout_resolver.make(readout) def forward(self, graph, input, all_loss=None, metric=None): """ From 50adc0215c98b40840ed7d8383cba1960ba48791 Mon Sep 17 00:00:00 2001 From: Charles Tapley Hoyt Date: Tue, 8 Feb 2022 15:38:00 +0100 Subject: [PATCH 5/6] Cleanup --- requirements.txt | 3 ++- setup.py | 
1 + torchdrug/layers/__init__.py | 2 +- torchdrug/models/chebnet.py | 13 +++++-------- torchdrug/models/gat.py | 2 ++ torchdrug/models/gcn.py | 15 ++++----------- torchdrug/models/gin.py | 6 ++++-- torchdrug/models/neuralfp.py | 6 ++++-- torchdrug/models/schnet.py | 3 +++ torchdrug/tasks/pretrain.py | 14 ++++++-------- 10 files changed, 32 insertions(+), 33 deletions(-) diff --git a/requirements.txt b/requirements.txt index 683eb652..df9a2378 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,4 +7,5 @@ matplotlib tqdm networkx ninja -jinja2 \ No newline at end of file +jinja2 +class-resolver diff --git a/setup.py b/setup.py index a9243870..21356d3a 100644 --- a/setup.py +++ b/setup.py @@ -39,6 +39,7 @@ "networkx", "ninja", "jinja2", + "class-resolver", ], python_requires=">=3.7,<3.9", classifiers=[ diff --git a/torchdrug/layers/__init__.py b/torchdrug/layers/__init__.py index 7d79f9be..ea184887 100644 --- a/torchdrug/layers/__init__.py +++ b/torchdrug/layers/__init__.py @@ -23,7 +23,7 @@ "MessagePassingBase", "GraphConv", "GraphAttentionConv", "RelationalGraphConv", "GraphIsomorphismConv", "NeuralFingerprintConv", "ContinuousFilterConv", "MessagePassing", "ChebyshevConv", "DiffPool", "MinCutPool", - "MeanReadout", "SumReadout", "MaxReadout", "Softmax", "Set2Set", "Sort", + "MeanReadout", "SumReadout", "MaxReadout", "Softmax", "Set2Set", "Sort", "readout_resolver", "Readout", "ConditionalFlow", "NodeSampler", "EdgeSampler", "distribution", "functional", diff --git a/torchdrug/models/chebnet.py b/torchdrug/models/chebnet.py index cb1793d8..521aaef2 100644 --- a/torchdrug/models/chebnet.py +++ b/torchdrug/models/chebnet.py @@ -1,10 +1,12 @@ from collections.abc import Sequence import torch +from class_resolver import Hint from torch import nn from torchdrug import core, layers from torchdrug.core import Registry as R +from torchdrug.layers import Readout, readout_resolver @R.register("models.ChebNet") @@ -25,11 +27,11 @@ class ChebyshevConvolutionalNetwork(nn.Module, core.Configurable): batch_norm (bool, optional): apply batch normalization or not activation (str or function, optional): activation function concat_hidden (bool, optional): concat hidden representations from all layers as output - readout (str, optional): readout function. Available functions are ``sum`` and ``mean``. + readout: readout function. Available functions are ``sum`` and ``mean``. 
""" def __init__(self, input_dim, hidden_dims, edge_input_dim=None, k=1, short_cut=False, batch_norm=False, - activation="relu", concat_hidden=False, readout="sum"): + activation="relu", concat_hidden=False, readout: Hint[Readout] = "sum"): super(ChebyshevConvolutionalNetwork, self).__init__() if not isinstance(hidden_dims, Sequence): @@ -45,12 +47,7 @@ def __init__(self, input_dim, hidden_dims, edge_input_dim=None, k=1, short_cut=F self.layers.append(layers.ChebyshevConv(self.dims[i], self.dims[i + 1], edge_input_dim, k, batch_norm, activation)) - if readout == "sum": - self.readout = layers.SumReadout() - elif readout == "mean": - self.readout = layers.MeanReadout() - else: - raise ValueError("Unknown readout `%s`" % readout) + self.readout = readout_resolver.make(readout) def forward(self, graph, input, all_loss=None, metric=None): """ diff --git a/torchdrug/models/gat.py b/torchdrug/models/gat.py index 211f8d0b..db2b5e72 100644 --- a/torchdrug/models/gat.py +++ b/torchdrug/models/gat.py @@ -1,10 +1,12 @@ from collections.abc import Sequence import torch +from class_resolver import Hint from torch import nn from torchdrug import core, layers from torchdrug.core import Registry as R +from torchdrug.layers import Readout, readout_resolver @R.register("models.GAT") diff --git a/torchdrug/models/gcn.py b/torchdrug/models/gcn.py index 968f2b1f..679ae28e 100644 --- a/torchdrug/models/gcn.py +++ b/torchdrug/models/gcn.py @@ -5,8 +5,8 @@ from torch import nn from torchdrug import core, layers -from torchdrug.layers import readout_resolver, Readout from torchdrug.core import Registry as R +from torchdrug.layers import Readout, readout_resolver @R.register("models.GCN") @@ -99,11 +99,11 @@ class RelationalGraphConvolutionalNetwork(nn.Module, core.Configurable): batch_norm (bool, optional): apply batch normalization or not activation (str or function, optional): activation function concat_hidden (bool, optional): concat hidden representations from all layers as output - readout (str, optional): readout function. Available functions are ``sum``, ``mean``, and ``max``. + readout: readout function. Available functions are ``sum``, ``mean``, and ``max``. 
""" def __init__(self, input_dim, hidden_dims, num_relation, edge_input_dim=None, short_cut=False, batch_norm=False, - activation="relu", concat_hidden=False, readout="sum"): + activation="relu", concat_hidden=False, readout: Hint[Readout] = "sum"): super(RelationalGraphConvolutionalNetwork, self).__init__() if not isinstance(hidden_dims, Sequence): @@ -120,14 +120,7 @@ def __init__(self, input_dim, hidden_dims, num_relation, edge_input_dim=None, sh self.layers.append(layers.RelationalGraphConv(self.dims[i], self.dims[i + 1], num_relation, edge_input_dim, batch_norm, activation)) - if readout == "sum": - self.readout = layers.SumReadout() - elif readout == "mean": - self.readout = layers.MeanReadout() - elif readout == "max": - self.readout = layers.MaxReadout() - else: - raise ValueError("Unknown readout `%s`" % readout) + self.readout = readout_resolver.make(readout) def forward(self, graph, input, all_loss=None, metric=None): """ diff --git a/torchdrug/models/gin.py b/torchdrug/models/gin.py index 4b945cf0..f32e2cae 100644 --- a/torchdrug/models/gin.py +++ b/torchdrug/models/gin.py @@ -1,10 +1,12 @@ from collections.abc import Sequence import torch +from class_resolver import Hint from torch import nn from torchdrug import core, layers from torchdrug.core import Registry as R +from torchdrug.layers import Readout, readout_resolver @R.register("models.GIN") @@ -26,12 +28,12 @@ class GraphIsomorphismNetwork(nn.Module, core.Configurable): batch_norm (bool, optional): apply batch normalization or not activation (str or function, optional): activation function concat_hidden (bool, optional): concat hidden representations from all layers as output - readout (str, optional): readout function. Available functions are ``sum``, ``mean``, and ``max``. + readout: readout function. Available functions are ``sum``, ``mean``, and ``max``. """ def __init__(self, input_dim=None, hidden_dims=None, edge_input_dim=None, num_mlp_layer=2, eps=0, learn_eps=False, short_cut=False, batch_norm=False, activation="relu", concat_hidden=False, - readout="sum"): + readout: Hint[Readout] = "sum"): super(GraphIsomorphismNetwork, self).__init__() if not isinstance(hidden_dims, Sequence): diff --git a/torchdrug/models/neuralfp.py b/torchdrug/models/neuralfp.py index 4d9a1ab6..a1f95129 100644 --- a/torchdrug/models/neuralfp.py +++ b/torchdrug/models/neuralfp.py @@ -1,11 +1,13 @@ from collections.abc import Sequence import torch +from class_resolver import Hint from torch import nn from torch.nn import functional as F from torchdrug import core, layers from torchdrug.core import Registry as R +from torchdrug.layers import Readout, readout_resolver @R.register("models.NeuralFP") @@ -25,11 +27,11 @@ class NeuralFingerprint(nn.Module, core.Configurable): batch_norm (bool, optional): apply batch normalization or not activation (str or function, optional): activation function concat_hidden (bool, optional): concat hidden representations from all layers as output - readout (str, optional): readout function. Available functions are ``sum`` and ``mean``. + readout: readout function. Available functions are ``sum``, ``mean``, and ``max``. 
""" def __init__(self, input_dim, output_dim, hidden_dims, edge_input_dim=None, short_cut=False, batch_norm=False, - activation="relu", concat_hidden=False, readout="sum"): + activation="relu", concat_hidden=False, readout: Hint[Readout] = "sum"): super(NeuralFingerprint, self).__init__() if not isinstance(hidden_dims, Sequence): diff --git a/torchdrug/models/schnet.py b/torchdrug/models/schnet.py index 3644e2ec..7861a619 100644 --- a/torchdrug/models/schnet.py +++ b/torchdrug/models/schnet.py @@ -1,10 +1,12 @@ from collections.abc import Sequence import torch +from class_resolver import Hint from torch import nn from torchdrug import core, layers from torchdrug.core import Registry as R +from torchdrug.layers import Readout, readout_resolver @R.register("models.SchNet") @@ -25,6 +27,7 @@ class SchNet(nn.Module, core.Configurable): batch_norm (bool, optional): apply batch normalization or not activation (str or function, optional): activation function concat_hidden (bool, optional): concat hidden representations from all layers as output + readout: readout function. Available functions are ``sum``, ``mean``, and ``max``. """ def __init__(self, input_dim, hidden_dims, edge_input_dim=None, cutoff=5, num_gaussian=100, short_cut=True, diff --git a/torchdrug/tasks/pretrain.py b/torchdrug/tasks/pretrain.py index 5e8deecd..8b9460e8 100644 --- a/torchdrug/tasks/pretrain.py +++ b/torchdrug/tasks/pretrain.py @@ -1,13 +1,14 @@ import copy import torch +from class_resolver import Hint from torch import nn from torch.nn import functional as F from torch_scatter import scatter_max, scatter_min from torchdrug import core, tasks, layers from torchdrug.data import constant -from torchdrug.layers import functional +from torchdrug.layers import functional, readout_resolver, Readout from torchdrug.core import Registry as R @@ -169,9 +170,10 @@ class ContextPrediction(tasks.Task, core.Configurable): r2 (int, optional): outer radius for context graphs readout (nn.Module, optional): readout function over context anchor nodes num_negative (int, optional): number of negative samples per positive sample + readout: readout function. Available functions are ``sum``, ``mean``, and ``max``. 
""" - def __init__(self, model, context_model=None, k=5, r1=4, r2=7, readout="mean", num_negative=1): + def __init__(self, model, context_model=None, k=5, r1=4, r2=7, readout: Hint[Readout] = "mean", num_negative=1): super(ContextPrediction, self).__init__() self.model = model self.k = k @@ -184,12 +186,8 @@ def __init__(self, model, context_model=None, k=5, r1=4, r2=7, readout="mean", n self.context_model = copy.deepcopy(model) else: self.context_model = context_model - if readout == "sum": - self.readout = layers.SumReadout() - elif readout == "mean": - self.readout = layers.MeanReadout() - else: - raise ValueError("Unknown readout `%s`" % readout) + + self.readout = readout_resolver.make(readout) def substruct_and_context(self, graph): center_index = (torch.rand(len(graph), device=self.device) * graph.num_nodes).long() From b62c14f6b67d9024db4d35d6fb53493ce5d71563 Mon Sep 17 00:00:00 2001 From: Zhaocheng Zhu Date: Fri, 18 Feb 2022 11:30:55 -0500 Subject: [PATCH 6/6] revert to naive branching resolution --- requirements.txt | 3 +-- setup.py | 1 - torchdrug/layers/__init__.py | 4 ++-- torchdrug/layers/readout.py | 17 +++-------------- torchdrug/models/chebnet.py | 13 +++++++++---- torchdrug/models/gat.py | 13 +++++++++---- torchdrug/models/gcn.py | 24 ++++++++++++++++++------ torchdrug/models/gin.py | 14 +++++++++----- torchdrug/models/neuralfp.py | 13 +++++++++---- torchdrug/models/schnet.py | 13 +++++++++---- torchdrug/tasks/pretrain.py | 14 ++++++++++---- 11 files changed, 79 insertions(+), 50 deletions(-) diff --git a/requirements.txt b/requirements.txt index df9a2378..683eb652 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,5 +7,4 @@ matplotlib tqdm networkx ninja -jinja2 -class-resolver +jinja2 \ No newline at end of file diff --git a/setup.py b/setup.py index 21356d3a..a9243870 100644 --- a/setup.py +++ b/setup.py @@ -39,7 +39,6 @@ "networkx", "ninja", "jinja2", - "class-resolver", ], python_requires=">=3.7,<3.9", classifiers=[ diff --git a/torchdrug/layers/__init__.py b/torchdrug/layers/__init__.py index ea184887..22c3b60b 100644 --- a/torchdrug/layers/__init__.py +++ b/torchdrug/layers/__init__.py @@ -3,7 +3,7 @@ from .conv import MessagePassingBase, GraphConv, GraphAttentionConv, RelationalGraphConv, GraphIsomorphismConv, \ NeuralFingerprintConv, ContinuousFilterConv, MessagePassing, ChebyshevConv from .pool import DiffPool, MinCutPool -from .readout import MeanReadout, SumReadout, MaxReadout, Softmax, Set2Set, Sort, readout_resolver, Readout +from .readout import MeanReadout, SumReadout, MaxReadout, Softmax, Set2Set, Sort from .flow import ConditionalFlow from .sampler import NodeSampler, EdgeSampler from . 
import distribution, functional @@ -23,7 +23,7 @@ "MessagePassingBase", "GraphConv", "GraphAttentionConv", "RelationalGraphConv", "GraphIsomorphismConv", "NeuralFingerprintConv", "ContinuousFilterConv", "MessagePassing", "ChebyshevConv", "DiffPool", "MinCutPool", - "MeanReadout", "SumReadout", "MaxReadout", "Softmax", "Set2Set", "Sort", "readout_resolver", "Readout", + "MeanReadout", "SumReadout", "MaxReadout", "Softmax", "Set2Set", "Sort", "ConditionalFlow", "NodeSampler", "EdgeSampler", "distribution", "functional", diff --git a/torchdrug/layers/readout.py b/torchdrug/layers/readout.py index 66ed64ea..3680d292 100644 --- a/torchdrug/layers/readout.py +++ b/torchdrug/layers/readout.py @@ -1,14 +1,9 @@ import torch from torch import nn from torch_scatter import scatter_mean, scatter_add, scatter_max -from class_resolver import ClassResolver -class Readout(nn.Module): - """A base class for readouts.""" - - -class MeanReadout(Readout): +class MeanReadout(nn.Module): """Mean readout operator over graphs with variadic sizes.""" def forward(self, graph, input): @@ -26,7 +21,7 @@ def forward(self, graph, input): return output -class SumReadout(Readout): +class SumReadout(nn.Module): """Sum readout operator over graphs with variadic sizes.""" def forward(self, graph, input): @@ -44,7 +39,7 @@ def forward(self, graph, input): return output -class MaxReadout(Readout): +class MaxReadout(nn.Module): """Max readout operator over graphs with variadic sizes.""" def forward(self, graph, input): @@ -62,12 +57,6 @@ def forward(self, graph, input): return output -readout_resolver = ClassResolver.from_subclasses( - Readout, - default=SumReadout, -) - - class Softmax(nn.Module): """Softmax operator over graphs with variadic sizes.""" diff --git a/torchdrug/models/chebnet.py b/torchdrug/models/chebnet.py index 521aaef2..86d2aeff 100644 --- a/torchdrug/models/chebnet.py +++ b/torchdrug/models/chebnet.py @@ -1,12 +1,10 @@ from collections.abc import Sequence import torch -from class_resolver import Hint from torch import nn from torchdrug import core, layers from torchdrug.core import Registry as R -from torchdrug.layers import Readout, readout_resolver @R.register("models.ChebNet") @@ -31,7 +29,7 @@ class ChebyshevConvolutionalNetwork(nn.Module, core.Configurable): """ def __init__(self, input_dim, hidden_dims, edge_input_dim=None, k=1, short_cut=False, batch_norm=False, - activation="relu", concat_hidden=False, readout: Hint[Readout] = "sum"): + activation="relu", concat_hidden=False, readout="sum"): super(ChebyshevConvolutionalNetwork, self).__init__() if not isinstance(hidden_dims, Sequence): @@ -47,7 +45,14 @@ def __init__(self, input_dim, hidden_dims, edge_input_dim=None, k=1, short_cut=F self.layers.append(layers.ChebyshevConv(self.dims[i], self.dims[i + 1], edge_input_dim, k, batch_norm, activation)) - self.readout = readout_resolver.make(readout) + if readout == "sum": + self.readout = layers.SumReadout() + elif readout == "mean": + self.readout = layers.MeanReadout() + elif readout == "max": + self.readout = layers.MaxReadout() + else: + raise ValueError("Unknown readout `%s`" % readout) def forward(self, graph, input, all_loss=None, metric=None): """ diff --git a/torchdrug/models/gat.py b/torchdrug/models/gat.py index db2b5e72..c22e848c 100644 --- a/torchdrug/models/gat.py +++ b/torchdrug/models/gat.py @@ -1,12 +1,10 @@ from collections.abc import Sequence import torch -from class_resolver import Hint from torch import nn from torchdrug import core, layers from torchdrug.core import Registry as R 
-from torchdrug.layers import Readout, readout_resolver @R.register("models.GAT") @@ -31,7 +29,7 @@ class GraphAttentionNetwork(nn.Module, core.Configurable): """ def __init__(self, input_dim, hidden_dims, edge_input_dim=None, num_head=1, negative_slope=0.2, short_cut=False, - batch_norm=False, activation="relu", concat_hidden=False, readout: Hint[Readout] = "sum"): + batch_norm=False, activation="relu", concat_hidden=False, readout="sum"): super(GraphAttentionNetwork, self).__init__() if not isinstance(hidden_dims, Sequence): @@ -47,7 +45,14 @@ def __init__(self, input_dim, hidden_dims, edge_input_dim=None, num_head=1, nega self.layers.append(layers.GraphAttentionConv(self.dims[i], self.dims[i + 1], edge_input_dim, num_head, negative_slope, batch_norm, activation)) - self.readout = readout_resolver.make(readout) + if readout == "sum": + self.readout = layers.SumReadout() + elif readout == "mean": + self.readout = layers.MeanReadout() + elif readout == "max": + self.readout = layers.MaxReadout() + else: + raise ValueError("Unknown readout `%s`" % readout) def forward(self, graph, input, all_loss=None, metric=None): """ diff --git a/torchdrug/models/gcn.py b/torchdrug/models/gcn.py index 679ae28e..352c39b3 100644 --- a/torchdrug/models/gcn.py +++ b/torchdrug/models/gcn.py @@ -1,12 +1,10 @@ from collections.abc import Sequence import torch -from class_resolver import Hint from torch import nn from torchdrug import core, layers from torchdrug.core import Registry as R -from torchdrug.layers import Readout, readout_resolver @R.register("models.GCN") @@ -29,7 +27,7 @@ class GraphConvolutionalNetwork(nn.Module, core.Configurable): """ def __init__(self, input_dim, hidden_dims, edge_input_dim=None, short_cut=False, batch_norm=False, - activation="relu", concat_hidden=False, readout: Hint[Readout] = "sum"): + activation="relu", concat_hidden=False, readout="sum"): super(GraphConvolutionalNetwork, self).__init__() if not isinstance(hidden_dims, Sequence): @@ -44,7 +42,14 @@ def __init__(self, input_dim, hidden_dims, edge_input_dim=None, short_cut=False, for i in range(len(self.dims) - 1): self.layers.append(layers.GraphConv(self.dims[i], self.dims[i + 1], edge_input_dim, batch_norm, activation)) - self.readout = readout_resolver.make(readout) + if readout == "sum": + self.readout = layers.SumReadout() + elif readout == "mean": + self.readout = layers.MeanReadout() + elif readout == "max": + self.readout = layers.MaxReadout() + else: + raise ValueError("Unknown readout `%s`" % readout) def forward(self, graph, input, all_loss=None, metric=None): """ @@ -103,7 +108,7 @@ class RelationalGraphConvolutionalNetwork(nn.Module, core.Configurable): """ def __init__(self, input_dim, hidden_dims, num_relation, edge_input_dim=None, short_cut=False, batch_norm=False, - activation="relu", concat_hidden=False, readout: Hint[Readout] = "sum"): + activation="relu", concat_hidden=False, readout="sum"): super(RelationalGraphConvolutionalNetwork, self).__init__() if not isinstance(hidden_dims, Sequence): @@ -120,7 +125,14 @@ def __init__(self, input_dim, hidden_dims, num_relation, edge_input_dim=None, sh self.layers.append(layers.RelationalGraphConv(self.dims[i], self.dims[i + 1], num_relation, edge_input_dim, batch_norm, activation)) - self.readout = readout_resolver.make(readout) + if readout == "sum": + self.readout = layers.SumReadout() + elif readout == "mean": + self.readout = layers.MeanReadout() + elif readout == "max": + self.readout = layers.MaxReadout() + else: + raise ValueError("Unknown readout `%s`" % 
readout) def forward(self, graph, input, all_loss=None, metric=None): """ diff --git a/torchdrug/models/gin.py b/torchdrug/models/gin.py index f32e2cae..f0d99de9 100644 --- a/torchdrug/models/gin.py +++ b/torchdrug/models/gin.py @@ -1,12 +1,10 @@ from collections.abc import Sequence import torch -from class_resolver import Hint from torch import nn from torchdrug import core, layers from torchdrug.core import Registry as R -from torchdrug.layers import Readout, readout_resolver @R.register("models.GIN") @@ -32,8 +30,7 @@ class GraphIsomorphismNetwork(nn.Module, core.Configurable): """ def __init__(self, input_dim=None, hidden_dims=None, edge_input_dim=None, num_mlp_layer=2, eps=0, learn_eps=False, - short_cut=False, batch_norm=False, activation="relu", concat_hidden=False, - readout: Hint[Readout] = "sum"): + short_cut=False, batch_norm=False, activation="relu", concat_hidden=False, readout="sum"): super(GraphIsomorphismNetwork, self).__init__() if not isinstance(hidden_dims, Sequence): @@ -50,7 +47,14 @@ def __init__(self, input_dim=None, hidden_dims=None, edge_input_dim=None, num_ml self.layers.append(layers.GraphIsomorphismConv(self.dims[i], self.dims[i + 1], edge_input_dim, layer_hidden_dims, eps, learn_eps, batch_norm, activation)) - self.readout = readout_resolver.make(readout) + if readout == "sum": + self.readout = layers.SumReadout() + elif readout == "mean": + self.readout = layers.MeanReadout() + elif readout == "max": + self.readout = layers.MaxReadout() + else: + raise ValueError("Unknown readout `%s`" % readout) def forward(self, graph, input, all_loss=None, metric=None): """ diff --git a/torchdrug/models/neuralfp.py b/torchdrug/models/neuralfp.py index a1f95129..b4e59794 100644 --- a/torchdrug/models/neuralfp.py +++ b/torchdrug/models/neuralfp.py @@ -1,13 +1,11 @@ from collections.abc import Sequence import torch -from class_resolver import Hint from torch import nn from torch.nn import functional as F from torchdrug import core, layers from torchdrug.core import Registry as R -from torchdrug.layers import Readout, readout_resolver @R.register("models.NeuralFP") @@ -31,7 +29,7 @@ class NeuralFingerprint(nn.Module, core.Configurable): """ def __init__(self, input_dim, output_dim, hidden_dims, edge_input_dim=None, short_cut=False, batch_norm=False, - activation="relu", concat_hidden=False, readout: Hint[Readout] = "sum"): + activation="relu", concat_hidden=False, readout="sum"): super(NeuralFingerprint, self).__init__() if not isinstance(hidden_dims, Sequence): @@ -49,7 +47,14 @@ def __init__(self, input_dim, output_dim, hidden_dims, edge_input_dim=None, shor batch_norm, activation)) self.linears.append(nn.Linear(self.dims[i + 1], output_dim)) - self.readout = readout_resolver.make(readout) + if readout == "sum": + self.readout = layers.SumReadout() + elif readout == "mean": + self.readout = layers.MeanReadout() + elif readout == "max": + self.readout = layers.MaxReadout() + else: + raise ValueError("Unknown readout `%s`" % readout) def forward(self, graph, input, all_loss=None, metric=None): """ diff --git a/torchdrug/models/schnet.py b/torchdrug/models/schnet.py index 7861a619..1a7cf09e 100644 --- a/torchdrug/models/schnet.py +++ b/torchdrug/models/schnet.py @@ -1,12 +1,10 @@ from collections.abc import Sequence import torch -from class_resolver import Hint from torch import nn from torchdrug import core, layers from torchdrug.core import Registry as R -from torchdrug.layers import Readout, readout_resolver @R.register("models.SchNet") @@ -31,7 +29,7 @@ class 
SchNet(nn.Module, core.Configurable): """ def __init__(self, input_dim, hidden_dims, edge_input_dim=None, cutoff=5, num_gaussian=100, short_cut=True, - batch_norm=False, activation="shifted_softplus", concat_hidden=False, readout: Hint[Readout] = "sum"): + batch_norm=False, activation="shifted_softplus", concat_hidden=False, readout="sum"): super(SchNet, self).__init__() if not isinstance(hidden_dims, Sequence): @@ -47,7 +45,14 @@ def __init__(self, input_dim, hidden_dims, edge_input_dim=None, cutoff=5, num_ga self.layers.append(layers.ContinuousFilterConv(self.dims[i], self.dims[i + 1], edge_input_dim, None, cutoff, num_gaussian, batch_norm, activation)) - self.readout = readout_resolver.make(readout) + if readout == "sum": + self.readout = layers.SumReadout() + elif readout == "mean": + self.readout = layers.MeanReadout() + elif readout == "max": + self.readout = layers.MaxReadout() + else: + raise ValueError("Unknown readout `%s`" % readout) def forward(self, graph, input, all_loss=None, metric=None): """ diff --git a/torchdrug/tasks/pretrain.py b/torchdrug/tasks/pretrain.py index 8b9460e8..a26bc5a4 100644 --- a/torchdrug/tasks/pretrain.py +++ b/torchdrug/tasks/pretrain.py @@ -1,14 +1,13 @@ import copy import torch -from class_resolver import Hint from torch import nn from torch.nn import functional as F from torch_scatter import scatter_max, scatter_min from torchdrug import core, tasks, layers from torchdrug.data import constant -from torchdrug.layers import functional, readout_resolver, Readout +from torchdrug.layers import functional from torchdrug.core import Registry as R @@ -173,7 +172,7 @@ class ContextPrediction(tasks.Task, core.Configurable): readout: readout function. Available functions are ``sum``, ``mean``, and ``max``. """ - def __init__(self, model, context_model=None, k=5, r1=4, r2=7, readout: Hint[Readout] = "mean", num_negative=1): + def __init__(self, model, context_model=None, k=5, r1=4, r2=7, readout="mean", num_negative=1): super(ContextPrediction, self).__init__() self.model = model self.k = k @@ -187,7 +186,14 @@ def __init__(self, model, context_model=None, k=5, r1=4, r2=7, readout: Hint[Rea else: self.context_model = context_model - self.readout = readout_resolver.make(readout) + if readout == "sum": + self.readout = layers.SumReadout() + elif readout == "mean": + self.readout = layers.MeanReadout() + elif readout == "max": + self.readout = layers.MaxReadout() + else: + raise ValueError("Unknown readout `%s`" % readout) def substruct_and_context(self, graph): center_index = (torch.rand(len(graph), device=self.device) * graph.num_nodes).long()
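
For readers unfamiliar with the class-resolver package that patch 4 introduces (and patch 6 later removes again), the following self-contained sketch shows how ``ClassResolver.from_subclasses`` and ``Hint`` behave. The ``Readout`` subclasses below are placeholder stand-ins, not torchdrug's scatter-based readout layers; only the string/class/None resolution that ``readout_resolver.make`` relies on is illustrated.

```python
# Minimal sketch of the pattern from patch 4, assuming only the class-resolver
# package (pip install class-resolver). The Readout subclasses are placeholders
# for torchdrug.layers.{Sum,Mean,Max}Readout and carry no scatter logic.
from class_resolver import ClassResolver, Hint


class Readout:
    """Stand-in base class for graph readout operators."""


class SumReadout(Readout):
    """Placeholder for torchdrug.layers.SumReadout."""


class MeanReadout(Readout):
    """Placeholder for torchdrug.layers.MeanReadout."""


class MaxReadout(Readout):
    """Placeholder for torchdrug.layers.MaxReadout."""


# Collect every subclass of Readout; the shared "Readout" suffix is stripped
# during key normalization, so the lookup keys become "sum", "mean", and "max".
readout_resolver = ClassResolver.from_subclasses(Readout, default=SumReadout)


def build_readout(readout: Hint[Readout] = "sum") -> Readout:
    # Hint[Readout] accepts a string key, a Readout subclass, an existing
    # instance, or None (which falls back to the declared default).
    return readout_resolver.make(readout)


print(type(build_readout("max")).__name__)        # MaxReadout
print(type(build_readout(MeanReadout)).__name__)  # MeanReadout
print(type(build_readout(None)).__name__)         # SumReadout (default)
```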
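After the final revert in patch 6, the user-visible change that survives is simply that ``max`` joins ``sum`` and ``mean`` as a valid readout string across the affected models. A hypothetical usage sketch follows; the model alias and dimensions are illustrative placeholders, not part of the patches themselves.

```python
# Hypothetical end-state usage after patch 6: readouts stay plain strings,
# but "max" is now accepted by GCN/GAT/GIN/RGCN/NeuralFP/ChebNet/SchNet.
from torchdrug import models

model = models.GCN(input_dim=69, hidden_dims=[256, 256], readout="max")

# Unknown values still fail fast with the original error message, e.g.:
# models.GCN(input_dim=69, hidden_dims=[256, 256], readout="median")
# -> ValueError: Unknown readout `median`
```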