Graphs

  • Data structure consisting of nodes (also known as vertices) and edges: G = (V, E)
  • Can represent real-world objects, examples:
    • words in a document,
    • documents in a citation network,
    • people and organizations in a social media network,
    • atoms in molecular structure
  • Matrix Representation: Adjacency Matrix A (see the sketch after this list)
    • For a graph with n nodes, A has shape (n, n)
    • If nodes i and j are connected, A(i, j) = A(j, i) = 1
    • Minor variations for directed graphs (A(i, j) != A(j, i)) and weighted graphs (A(i, j) = w)
  • Graphs can optionally have node features X
    • For a graph with n nodes and feature vectors of size f, X has shape (n, f)
  • Graphs can optionally also have edge features
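
A minimal sketch of these representations in numpy, using a made-up 4-node undirected graph (the edges and feature size are arbitrary illustrations):

```python
import numpy as np

# Toy undirected graph with n = 4 nodes and edges 0-1, 0-2, 1-2, 2-3
n = 4
edges = [(0, 1), (0, 2), (1, 2), (2, 3)]

# Adjacency matrix A of shape (n, n); symmetric because the graph is undirected
A = np.zeros((n, n))
for i, j in edges:
    A[i, j] = 1
    A[j, i] = 1

# Optional node feature matrix X of shape (n, f)
f = 2
X = np.random.rand(n, f)

print(A)         # (4, 4) symmetric 0/1 matrix
print(X.shape)   # (4, 2)
```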

Machine Learning Models

  • Goal: Learn a mapping F from an input space X to an output space Y
  • Hypothesize some model M with random weights θ
  • Formulate the task as an optimization problem: find the weights θ that minimize the loss L(M(x; θ), y) over the training data
  • Use gradient descent to update the model weights until convergence (see the sketch after this list)
  • Test the fitted model for accuracy on new data; try a different model M if needed
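
A minimal sketch of this loop in PyTorch, with a made-up linear model and random data purely for illustration:

```python
import torch

# Made-up training data: 100 examples, 3 features, regression target
X = torch.randn(100, 3)
y = torch.randn(100, 1)

# Hypothesized model M with random initial weights θ
model = torch.nn.Linear(3, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = torch.nn.MSELoss()

# Gradient descent on the loss until (approximate) convergence
for epoch in range(100):
    optimizer.zero_grad()
    loss = loss_fn(model(X), y)   # the optimization objective
    loss.backward()               # gradients of the loss w.r.t. θ
    optimizer.step()              # weight update
```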

Graph Models for Machine Learning

  • ML and DL tools are optimized for simple structures
    • Convolutional Neural Networks
      • images
      • regular lattice structures
    • Recurrent Neural Networks
      • text and sequence data
      • time series data
  • Problems with graphs
    • Topological complexity
    • Indeterminate size
    • No fixed node ordering (models must be permutation invariant)
    • Instances not independent of each other

Extending Convolutions to Graphs


[Figure omitted. Source: CS-224W slide 06-GNN-1.pdf]


Deep Learning architecture for images

  • Multiple layers of convolution + non-linearity + pooling (see the sketch after this list)
  • Fully connected layer(s) with non-linearity convert the feature map to the output prediction
  • Loss function: Cross Entropy for classification, Mean Squared Error for regression
  • Gradient descent is used to optimize the loss function
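
A minimal sketch of such a stack in PyTorch; the layer sizes and the 10-class, 32x32-input assumption are made up for illustration:

```python
import torch.nn as nn

# Convolution + non-linearity + pooling blocks, then a fully connected head
cnn = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
    nn.Conv2d(16, 32, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
    nn.Flatten(),
    nn.Linear(32 * 8 * 8, 10),   # assumes 32x32 inputs, pooled twice to 8x8
)
loss_fn = nn.CrossEntropyLoss()  # classification loss
```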


Deep Learning architecture for graphs

  • Each graph convolution layer aggregates 1-hop neighborhood information for each node
  • In a GNN with k convolution layers, each node has information about nodes up to k hops away
  • The value of k is dictated by the application and is usually small (unlike very deep CNNs)


[Figure omitted. Source: CS-224W slide 06-GNN-1.pdf]


Computation Graph

  • Aggregate information from neighbors
  • Apply a neural network to the aggregated information (gray boxes in the figure)
  • Each node defines a computation graph based on its neighborhood, i.e. each node has its own neural network!
  • Achieved using message passing


[Figure omitted. Source: CS-224W slide 06-GNN-1.pdf]


Message Passing

  • Elegant approach to handle irregularity (diversity of computation graphs) in GNN
  • Message Passing steps (see the sketch after this list)
    • For each node in the graph, gather all neighbor node embeddings (messages)
    • Aggregate all messages via an aggregation function (such as sum or average)
    • Pass the pooled messages through an update function, usually a learned neural network
  • Reference: the section "Passing messages between parts of the graph" in A Gentle Introduction to Graph Neural Networks
  • More coverage of Message Passing in Part III of this tutorial
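
A minimal sketch of one message passing round in PyTorch, assuming a dense adjacency matrix A and node embeddings H; the function name, mean aggregation, and MLP update are illustrative choices:

```python
import torch
import torch.nn as nn

def message_passing_step(A, H, update):
    """One round of message passing with mean aggregation.

    A: (n, n) adjacency matrix, H: (n, d) node embeddings,
    update: learned update function, e.g. a small MLP.
    """
    # Gather + aggregate: average each node's neighbor embeddings
    deg = A.sum(dim=1, keepdim=True).clamp(min=1)   # guard against isolated nodes
    messages = (A @ H) / deg
    # Update: pass the pooled messages through the learned function
    return update(messages)

n, d = 4, 8
A = (torch.rand(n, n) > 0.5).float()
H = torch.randn(n, d)
update = nn.Sequential(nn.Linear(d, d), nn.ReLU())
H_next = message_passing_step(A, H, update)   # (4, 8)
```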

PyTorch Geometric (PyG)

PyG (PyTorch Geometric) is a library built upon PyTorch to easily write and train Graph Neural Networks (GNNs) for a wide range of applications related to structured data.

It consists of various methods for deep learning on graphs and other irregular structures, also known as geometric deep learning, from a variety of published papers. In addition, it consists of easy-to-use mini-batch loaders for operating on many small and single giant graphs, multi GPU-support, a large number of common benchmark datasets (based on simple interfaces to create your own), and helpful transforms, both for learning on arbitrary graphs as well as on 3D meshes or point clouds.


Popular PyG Graph Layers

  • Main difference between them is the aggregation strategy (see the usage sketch after this list)
  • Graph Convolutional Network (GCN)
    • Aggregates self-features and neighbor features
  • Graph Attention Network (GAT)
    • Aggregates neighbor features with weights derived from an attention mechanism (between self-features and all neighbor features)
    • Attention is computed with a Bahdanau-style feedforward network
  • GraphSAGE (SAmple and aggreGatE)
    • Samples a subset of neighbors instead of using all of them (for scalability)
    • Importance sampling
      • Define the node neighborhood using random walks
      • Sum up the importance scores generated by the random walks
    • Can use MEAN, MAX or SUM as aggregation functions
  • Graph Isomorphism Network (GIN)
    • Uses SUM aggregation because it is better than MEAN and MAX aggregation for detecting graph similarity (isomorphism)
  • Available in PyG as GCNConv, SAGEConv, GATConv and GINConv
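
A short sketch of instantiating these layers in PyG; the channel sizes and toy graph are arbitrary, and all four layers share the same call signature:

```python
import torch
import torch.nn as nn
from torch_geometric.nn import GCNConv, SAGEConv, GATConv, GINConv

in_dim, out_dim = 16, 32

gcn = GCNConv(in_dim, out_dim)
sage = SAGEConv(in_dim, out_dim)          # aggregation defaults to mean
gat = GATConv(in_dim, out_dim, heads=1)
gin = GINConv(nn.Sequential(              # GIN wraps a small MLP
    nn.Linear(in_dim, out_dim), nn.ReLU(), nn.Linear(out_dim, out_dim)))

x = torch.randn(4, in_dim)                # 4 nodes, 16 features each
edge_index = torch.tensor([[0, 1, 1, 2],  # COO edges: 0->1, 1->0, 1->2, 2->3
                           [1, 0, 2, 3]])
out = gcn(x, edge_index)                  # (4, 32); same call for the others
```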

PyG DataLoaders

  • Extension of PyTorch DataLoaders (see the sketch after this list)
  • DataLoader (torch_geometric.loader.DataLoader) -- merges Data objects from a Dataset into a mini-batch
  • Dataset (torch_geometric.data.Dataset) -- wrapper for creating graph datasets
  • Data (torch_geometric.data.Data) -- represents a single graph in PyG, with the following attributes by default:
    • data.x -- node feature matrix with shape (num_nodes, num_node_features)
    • data.edge_index -- edges in COO (coordinate) format with shape (2, num_edges)
    • data.edge_attr -- edge feature matrix with shape (num_edges, num_edge_features)
    • data.y -- target matrix with shape (num_nodes, *)
    • data.pos -- node position matrix with shape (num_nodes, num_dimensions)
  • Parallelization over a mini-batch is achieved by creating a block-diagonal adjacency matrix (defined by edge_index) and concatenating the feature and target matrices in the node dimension; this allows examples in a single batch to have different numbers of nodes and edges
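
A minimal sketch of these pieces, using a made-up 3-node graph repeated to form a toy dataset:

```python
import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader

# A single toy graph: 3 nodes, 2 undirected edges (stored as 4 directed edges)
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])
x = torch.randn(3, 8)             # (num_nodes, num_node_features)
y = torch.tensor([0, 1, 0])       # one label per node
data = Data(x=x, edge_index=edge_index, y=y)

# Batching builds one big block-diagonal graph out of several Data objects
loader = DataLoader([data] * 10, batch_size=4)
batch = next(iter(loader))
print(batch.num_graphs)   # 4
print(batch.x.shape)      # torch.Size([12, 8]): nodes concatenated
```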


GNN Applications

  • Node classification (see the sketch at the end of this section)
    • Supervised -- label items (represented as nodes in a graph) by looking at the labels of their neighbors
    • Unsupervised -- use random-walk based embeddings or other graph features to generate labels
  • Graph classification
    • Classify a graph into one of several categories
    • Examples -- determine if a protein is an enzyme, determine if a chemical is toxic, categorize documents (NLP)
  • Link Prediction
    • Predict whether there is a connection between two entities in a graph
    • Example -- infer / predict social connections in a social network graph
  • Graph clustering
    • Use a GNN (without the classifier head) as an encoder, then cluster the feature maps
  • Generative Graph Models
    • Use a Variational AutoEncoder (VAE) that learns to predict the graph's Adjacency Matrix (treating it like an image)
    • Or build the graph sequentially, starting from a subgraph and adding nodes and edges one at a time
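
As a closing sketch, a minimal two-layer GCN for node classification in PyG; the random data, layer sizes, and class count are illustrative assumptions:

```python
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class NodeClassifier(torch.nn.Module):
    """Two-layer GCN: each node sees its 2-hop neighborhood."""
    def __init__(self, num_features, num_classes, hidden=16):
        super().__init__()
        self.conv1 = GCNConv(num_features, hidden)
        self.conv2 = GCNConv(hidden, num_classes)

    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        return self.conv2(x, edge_index)    # per-node class logits

# Illustrative usage with random data: 100 nodes, 8 features, 3 classes
x = torch.randn(100, 8)
edge_index = torch.randint(0, 100, (2, 400))    # random COO edges
model = NodeClassifier(num_features=8, num_classes=3)
logits = model(x, edge_index)                   # (100, 3)
loss = F.cross_entropy(logits, torch.randint(0, 3, (100,)))
```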