revert to naive branching resolution
KiddoZhu committed Feb 18, 2022
1 parent 50adc02 commit b62c14f
Showing 11 changed files with 79 additions and 50 deletions.
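
For context, the diff drops the class-resolver dependency and replaces the resolver-based readout lookup with explicit string branching in each model constructor. Below is a minimal standalone sketch of that branching; the `resolve_readout` helper name is hypothetical, since the commit inlines this logic directly in every `__init__` rather than factoring it out.

```python
from torchdrug import layers


def resolve_readout(readout="sum"):
    """Hypothetical helper mirroring the branching the commit adds to each model."""
    if readout == "sum":
        return layers.SumReadout()
    elif readout == "mean":
        return layers.MeanReadout()
    elif readout == "max":
        return layers.MaxReadout()
    else:
        raise ValueError("Unknown readout `%s`" % readout)
```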
3 changes: 1 addition & 2 deletions requirements.txt
@@ -7,5 +7,4 @@ matplotlib
tqdm
networkx
ninja
jinja2
class-resolver
jinja2
1 change: 0 additions & 1 deletion setup.py
@@ -39,7 +39,6 @@
"networkx",
"ninja",
"jinja2",
"class-resolver",
],
python_requires=">=3.7,<3.9",
classifiers=[
4 changes: 2 additions & 2 deletions torchdrug/layers/__init__.py
@@ -3,7 +3,7 @@
from .conv import MessagePassingBase, GraphConv, GraphAttentionConv, RelationalGraphConv, GraphIsomorphismConv, \
NeuralFingerprintConv, ContinuousFilterConv, MessagePassing, ChebyshevConv
from .pool import DiffPool, MinCutPool
from .readout import MeanReadout, SumReadout, MaxReadout, Softmax, Set2Set, Sort, readout_resolver, Readout
from .readout import MeanReadout, SumReadout, MaxReadout, Softmax, Set2Set, Sort
from .flow import ConditionalFlow
from .sampler import NodeSampler, EdgeSampler
from . import distribution, functional
@@ -23,7 +23,7 @@
"MessagePassingBase", "GraphConv", "GraphAttentionConv", "RelationalGraphConv", "GraphIsomorphismConv",
"NeuralFingerprintConv", "ContinuousFilterConv", "MessagePassing", "ChebyshevConv",
"DiffPool", "MinCutPool",
"MeanReadout", "SumReadout", "MaxReadout", "Softmax", "Set2Set", "Sort", "readout_resolver", "Readout",
"MeanReadout", "SumReadout", "MaxReadout", "Softmax", "Set2Set", "Sort",
"ConditionalFlow",
"NodeSampler", "EdgeSampler",
"distribution", "functional",
17 changes: 3 additions & 14 deletions torchdrug/layers/readout.py
@@ -1,14 +1,9 @@
import torch
from torch import nn
from torch_scatter import scatter_mean, scatter_add, scatter_max
from class_resolver import ClassResolver


class Readout(nn.Module):
"""A base class for readouts."""


class MeanReadout(Readout):
class MeanReadout(nn.Module):
"""Mean readout operator over graphs with variadic sizes."""

def forward(self, graph, input):
@@ -26,7 +21,7 @@ def forward(self, graph, input):
return output


class SumReadout(Readout):
class SumReadout(nn.Module):
"""Sum readout operator over graphs with variadic sizes."""

def forward(self, graph, input):
@@ -44,7 +39,7 @@ def forward(self, graph, input):
return output


class MaxReadout(Readout):
class MaxReadout(nn.Module):
"""Max readout operator over graphs with variadic sizes."""

def forward(self, graph, input):
@@ -62,12 +57,6 @@ def forward(self, graph, input):
return output


readout_resolver = ClassResolver.from_subclasses(
Readout,
default=SumReadout,
)


class Softmax(nn.Module):
"""Softmax operator over graphs with variadic sizes."""

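
For contrast, here is a sketch of the class-resolver path that the hunk above deletes, reassembled from the removed lines with stub readout classes; the behavior noted in the comments assumes the class_resolver package's usual name normalization and default handling.

```python
from class_resolver import ClassResolver
from torch import nn


class Readout(nn.Module):
    """Base class for readouts (deleted by this commit)."""


class MeanReadout(Readout):
    """Stub for illustration; the real class implements forward()."""


class SumReadout(Readout):
    """Stub for illustration."""


class MaxReadout(Readout):
    """Stub for illustration."""


readout_resolver = ClassResolver.from_subclasses(Readout, default=SumReadout)

readout = readout_resolver.make("mean")  # resolves the name to MeanReadout()
default = readout_resolver.make(None)    # falls back to the declared default, SumReadout()
```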
13 changes: 9 additions & 4 deletions torchdrug/models/chebnet.py
@@ -1,12 +1,10 @@
from collections.abc import Sequence

import torch
from class_resolver import Hint
from torch import nn

from torchdrug import core, layers
from torchdrug.core import Registry as R
from torchdrug.layers import Readout, readout_resolver


@R.register("models.ChebNet")
@@ -31,7 +29,7 @@ class ChebyshevConvolutionalNetwork(nn.Module, core.Configurable):
"""

def __init__(self, input_dim, hidden_dims, edge_input_dim=None, k=1, short_cut=False, batch_norm=False,
activation="relu", concat_hidden=False, readout: Hint[Readout] = "sum"):
activation="relu", concat_hidden=False, readout="sum"):
super(ChebyshevConvolutionalNetwork, self).__init__()

if not isinstance(hidden_dims, Sequence):
@@ -47,7 +45,14 @@ def __init__(self, input_dim, hidden_dims, edge_input_dim=None, k=1, short_cut=F
self.layers.append(layers.ChebyshevConv(self.dims[i], self.dims[i + 1], edge_input_dim, k,
batch_norm, activation))

self.readout = readout_resolver.make(readout)
if readout == "sum":
self.readout = layers.SumReadout()
elif readout == "mean":
self.readout = layers.MeanReadout()
elif readout == "max":
self.readout = layers.MaxReadout()
else:
raise ValueError("Unknown readout `%s`" % readout)

def forward(self, graph, input, all_loss=None, metric=None):
"""
13 changes: 9 additions & 4 deletions torchdrug/models/gat.py
@@ -1,12 +1,10 @@
from collections.abc import Sequence

import torch
from class_resolver import Hint
from torch import nn

from torchdrug import core, layers
from torchdrug.core import Registry as R
from torchdrug.layers import Readout, readout_resolver


@R.register("models.GAT")
@@ -31,7 +29,7 @@ class GraphAttentionNetwork(nn.Module, core.Configurable):
"""

def __init__(self, input_dim, hidden_dims, edge_input_dim=None, num_head=1, negative_slope=0.2, short_cut=False,
batch_norm=False, activation="relu", concat_hidden=False, readout: Hint[Readout] = "sum"):
batch_norm=False, activation="relu", concat_hidden=False, readout="sum"):
super(GraphAttentionNetwork, self).__init__()

if not isinstance(hidden_dims, Sequence):
@@ -47,7 +45,14 @@ def __init__(self, input_dim, hidden_dims, edge_input_dim=None, num_head=1, nega
self.layers.append(layers.GraphAttentionConv(self.dims[i], self.dims[i + 1], edge_input_dim, num_head,
negative_slope, batch_norm, activation))

self.readout = readout_resolver.make(readout)
if readout == "sum":
self.readout = layers.SumReadout()
elif readout == "mean":
self.readout = layers.MeanReadout()
elif readout == "max":
self.readout = layers.MaxReadout()
else:
raise ValueError("Unknown readout `%s`" % readout)

def forward(self, graph, input, all_loss=None, metric=None):
"""
24 changes: 18 additions & 6 deletions torchdrug/models/gcn.py
@@ -1,12 +1,10 @@
from collections.abc import Sequence

import torch
from class_resolver import Hint
from torch import nn

from torchdrug import core, layers
from torchdrug.core import Registry as R
from torchdrug.layers import Readout, readout_resolver


@R.register("models.GCN")
@@ -29,7 +27,7 @@ class GraphConvolutionalNetwork(nn.Module, core.Configurable):
"""

def __init__(self, input_dim, hidden_dims, edge_input_dim=None, short_cut=False, batch_norm=False,
activation="relu", concat_hidden=False, readout: Hint[Readout] = "sum"):
activation="relu", concat_hidden=False, readout="sum"):
super(GraphConvolutionalNetwork, self).__init__()

if not isinstance(hidden_dims, Sequence):
@@ -44,7 +42,14 @@ def __init__(self, input_dim, hidden_dims, edge_input_dim=None, short_cut=False,
for i in range(len(self.dims) - 1):
self.layers.append(layers.GraphConv(self.dims[i], self.dims[i + 1], edge_input_dim, batch_norm, activation))

self.readout = readout_resolver.make(readout)
if readout == "sum":
self.readout = layers.SumReadout()
elif readout == "mean":
self.readout = layers.MeanReadout()
elif readout == "max":
self.readout = layers.MaxReadout()
else:
raise ValueError("Unknown readout `%s`" % readout)

def forward(self, graph, input, all_loss=None, metric=None):
"""
@@ -103,7 +108,7 @@ class RelationalGraphConvolutionalNetwork(nn.Module, core.Configurable):
"""

def __init__(self, input_dim, hidden_dims, num_relation, edge_input_dim=None, short_cut=False, batch_norm=False,
activation="relu", concat_hidden=False, readout: Hint[Readout] = "sum"):
activation="relu", concat_hidden=False, readout="sum"):
super(RelationalGraphConvolutionalNetwork, self).__init__()

if not isinstance(hidden_dims, Sequence):
@@ -120,7 +125,14 @@ def __init__(self, input_dim, hidden_dims, num_relation, edge_input_dim=None, sh
self.layers.append(layers.RelationalGraphConv(self.dims[i], self.dims[i + 1], num_relation, edge_input_dim,
batch_norm, activation))

self.readout = readout_resolver.make(readout)
if readout == "sum":
self.readout = layers.SumReadout()
elif readout == "mean":
self.readout = layers.MeanReadout()
elif readout == "max":
self.readout = layers.MaxReadout()
else:
raise ValueError("Unknown readout `%s`" % readout)

def forward(self, graph, input, all_loss=None, metric=None):
"""
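
A usage sketch with assumed, illustrative argument values: after this change the readout argument is a plain string matched by the branching above, so only "sum", "mean", and "max" are accepted.

```python
from torchdrug import models

# Illustrative sizes only; assumes GraphConvolutionalNetwork is exported from
# torchdrug.models as in the upstream package.
model = models.GraphConvolutionalNetwork(
    input_dim=69,
    hidden_dims=[256, 256, 256],
    readout="mean",  # matched by the if/elif branching; "sum" and "max" also work
)
```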
14 changes: 9 additions & 5 deletions torchdrug/models/gin.py
@@ -1,12 +1,10 @@
from collections.abc import Sequence

import torch
from class_resolver import Hint
from torch import nn

from torchdrug import core, layers
from torchdrug.core import Registry as R
from torchdrug.layers import Readout, readout_resolver


@R.register("models.GIN")
@@ -32,8 +30,7 @@ class GraphIsomorphismNetwork(nn.Module, core.Configurable):
"""

def __init__(self, input_dim=None, hidden_dims=None, edge_input_dim=None, num_mlp_layer=2, eps=0, learn_eps=False,
short_cut=False, batch_norm=False, activation="relu", concat_hidden=False,
readout: Hint[Readout] = "sum"):
short_cut=False, batch_norm=False, activation="relu", concat_hidden=False, readout="sum"):
super(GraphIsomorphismNetwork, self).__init__()

if not isinstance(hidden_dims, Sequence):
@@ -50,7 +47,14 @@ def __init__(self, input_dim=None, hidden_dims=None, edge_input_dim=None, num_ml
self.layers.append(layers.GraphIsomorphismConv(self.dims[i], self.dims[i + 1], edge_input_dim,
layer_hidden_dims, eps, learn_eps, batch_norm, activation))

self.readout = readout_resolver.make(readout)
if readout == "sum":
self.readout = layers.SumReadout()
elif readout == "mean":
self.readout = layers.MeanReadout()
elif readout == "max":
self.readout = layers.MaxReadout()
else:
raise ValueError("Unknown readout `%s`" % readout)

def forward(self, graph, input, all_loss=None, metric=None):
"""
13 changes: 9 additions & 4 deletions torchdrug/models/neuralfp.py
@@ -1,13 +1,11 @@
from collections.abc import Sequence

import torch
from class_resolver import Hint
from torch import nn
from torch.nn import functional as F

from torchdrug import core, layers
from torchdrug.core import Registry as R
from torchdrug.layers import Readout, readout_resolver


@R.register("models.NeuralFP")
@@ -31,7 +29,7 @@ class NeuralFingerprint(nn.Module, core.Configurable):
"""

def __init__(self, input_dim, output_dim, hidden_dims, edge_input_dim=None, short_cut=False, batch_norm=False,
activation="relu", concat_hidden=False, readout: Hint[Readout] = "sum"):
activation="relu", concat_hidden=False, readout="sum"):
super(NeuralFingerprint, self).__init__()

if not isinstance(hidden_dims, Sequence):
@@ -49,7 +47,14 @@ def __init__(self, input_dim, output_dim, hidden_dims, edge_input_dim=None, shor
batch_norm, activation))
self.linears.append(nn.Linear(self.dims[i + 1], output_dim))

self.readout = readout_resolver.make(readout)
if readout == "sum":
self.readout = layers.SumReadout()
elif readout == "mean":
self.readout = layers.MeanReadout()
elif readout == "max":
self.readout = layers.MaxReadout()
else:
raise ValueError("Unknown readout `%s`" % readout)

def forward(self, graph, input, all_loss=None, metric=None):
"""
13 changes: 9 additions & 4 deletions torchdrug/models/schnet.py
@@ -1,12 +1,10 @@
from collections.abc import Sequence

import torch
from class_resolver import Hint
from torch import nn

from torchdrug import core, layers
from torchdrug.core import Registry as R
from torchdrug.layers import Readout, readout_resolver


@R.register("models.SchNet")
@@ -31,7 +29,7 @@ class SchNet(nn.Module, core.Configurable):
"""

def __init__(self, input_dim, hidden_dims, edge_input_dim=None, cutoff=5, num_gaussian=100, short_cut=True,
batch_norm=False, activation="shifted_softplus", concat_hidden=False, readout: Hint[Readout] = "sum"):
batch_norm=False, activation="shifted_softplus", concat_hidden=False, readout="sum"):
super(SchNet, self).__init__()

if not isinstance(hidden_dims, Sequence):
@@ -47,7 +45,14 @@ def __init__(self, input_dim, hidden_dims, edge_input_dim=None, cutoff=5, num_ga
self.layers.append(layers.ContinuousFilterConv(self.dims[i], self.dims[i + 1], edge_input_dim, None, cutoff,
num_gaussian, batch_norm, activation))

self.readout = readout_resolver.make(readout)
if readout == "sum":
self.readout = layers.SumReadout()
elif readout == "mean":
self.readout = layers.MeanReadout()
elif readout == "max":
self.readout = layers.MaxReadout()
else:
raise ValueError("Unknown readout `%s`" % readout)

def forward(self, graph, input, all_loss=None, metric=None):
"""
14 changes: 10 additions & 4 deletions torchdrug/tasks/pretrain.py
@@ -1,14 +1,13 @@
import copy

import torch
from class_resolver import Hint
from torch import nn
from torch.nn import functional as F
from torch_scatter import scatter_max, scatter_min

from torchdrug import core, tasks, layers
from torchdrug.data import constant
from torchdrug.layers import functional, readout_resolver, Readout
from torchdrug.layers import functional
from torchdrug.core import Registry as R


@@ -173,7 +172,7 @@ class ContextPrediction(tasks.Task, core.Configurable):
readout: readout function. Available functions are ``sum``, ``mean``, and ``max``.
"""

def __init__(self, model, context_model=None, k=5, r1=4, r2=7, readout: Hint[Readout] = "mean", num_negative=1):
def __init__(self, model, context_model=None, k=5, r1=4, r2=7, readout="mean", num_negative=1):
super(ContextPrediction, self).__init__()
self.model = model
self.k = k
@@ -187,7 +186,14 @@ def __init__(self, model, context_model=None, k=5, r1=4, r2=7, readout: Hint[Rea
else:
self.context_model = context_model

self.readout = readout_resolver.make(readout)
if readout == "sum":
self.readout = layers.SumReadout()
elif readout == "mean":
self.readout = layers.MeanReadout()
elif readout == "max":
self.readout = layers.MaxReadout()
else:
raise ValueError("Unknown readout `%s`" % readout)

def substruct_and_context(self, graph):
center_index = (torch.rand(len(graph), device=self.device) * graph.num_nodes).long()
