My graph neural network journey - part 1
Graph Neural Network (GNN) Notes
Disclaimer: this is the first post of a series, based on a Stanford YouTube course that I cite in the reference section below. These posts help me better understand what I am studying and, hopefully, give you a high-level description of GNNs and how they work. I will re-elaborate and condense what that course contains, but for more specific notions please watch the original material, and if I make some "grezza" (Roman slang for error/mistake), let me know through my contacts.
Graphs, graphs, graphs
Introduction
The course is built on graphs, so let's briefly talk about them. It is well known that many things can be modeled as a graph, a social network for example, but what is a graph? Simply a collection of nodes and edges. I could now talk about the plethora of graph types, but I won't (keep it simple).
Various tasks can be performed on graphs to reach a given goal. These tasks can be divided into four levels: node, edge, community, and graph. If we want to categorize a user or an item, we talk about node classification, a node-level task. Do you want to know if someone will buy that item? LINK PREDICTION! An edge-level task. What about the new drug that has just been discovered? Graph classification (graph-level), and so on...
Concerning my research interests, the most relevant task is link prediction. It can be framed on a bipartite graph (users-items) whose edges represent the items bought/rated/viewed by a user. The task is to predict the missing edges, which are the recommendations.
How to represent something with a graph?
To represent a domain with a graph we use nodes and edges, as mentioned previously, but what if I asked you to represent some phone calls and a social network? Do you think the resulting graphs would be the same? Nope. A friendship is a mutual connection: if I am your friend, you are my friend. A phone call, instead, is directed (unidirectional). So we have to pay attention to what we are modeling. Sometimes a problem can be modeled in just one way, but complex domains can often be modeled in various ways.
Some graph theory that you should know (you can watch the reference videos to gain more confidence):
- Node Degree
- Bipartite Graph
- Folded/Projected Bipartite Graph
- Adjacency Matrix
- Adjacency list (to resolve the sparsity problem)
- ...
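Two of the representations listed above, the adjacency matrix and the adjacency list, can be sketched in a few lines of plain Python. The 4-node graph below is a made-up example, not one from the course:

```python
# A small undirected graph on 4 nodes, stored both ways.
# The edge set is a made-up example for illustration.
edges = [(0, 1), (0, 2), (1, 2), (2, 3)]
n = 4

# Adjacency matrix: n x n, symmetric because the graph is undirected.
adj_matrix = [[0] * n for _ in range(n)]
for u, v in edges:
    adj_matrix[u][v] = 1
    adj_matrix[v][u] = 1

# Adjacency list: stores only the edges that exist,
# which sidesteps the sparsity problem of the matrix.
adj_list = {i: [] for i in range(n)}
for u, v in edges:
    adj_list[u].append(v)
    adj_list[v].append(u)

# Node degree falls out of either representation.
degree = {i: len(adj_list[i]) for i in range(n)}
print(degree)  # {0: 2, 1: 2, 2: 3, 3: 1}
```

For a directed graph (like the phone calls above) you would simply drop the symmetric assignments.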
Graph feature design
Good features = good performance
A good set of features is key to obtaining a network that performs well. For example, suppose we want to make predictions for a set of nodes. Given the graph G = (V, E), we want to learn a function f.
- Input: G = (V, E)
- Goal: f : V -> R
- Task: node classification - predict colors!
In the course example, we can assume that if the degree of a node is greater than 1 its color is green. By choosing the right feature (the node degree) we can make that prediction.
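That degree-based rule is trivial to code. The toy graph below is a hypothetical stand-in for the figure in the course slides:

```python
# Toy node classifier: predict "green" when a node's degree is
# greater than 1, "red" otherwise. The graph is a made-up stand-in
# for the course's figure.
adj_list = {0: [1, 2], 1: [0, 2], 2: [0, 1, 3], 3: [2]}

def predict_color(node, adj):
    return "green" if len(adj[node]) > 1 else "red"

colors = {v: predict_color(v, adj_list) for v in adj_list}
print(colors)  # node 3 has degree 1, so it is the only "red" node
```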
Now let's dive into the three main types of features: Node features, Link features, Graph features.
Node features
The simplest node feature is the degree, but it has a problem: it treats all neighbors as equally important. A node with many connections that sits in a central position of a spring-embedding layout is more important than the outer nodes. So we can also consider node centrality, which comes in several flavors, some of which are:
- Eigenvector centrality: I am as important as my neighbors are
- Betweenness centrality: I am important if many of the shortest paths between pairs of nodes pass through me
- Closeness centrality: I am important if the sum of my shortest-path distances to all other nodes is small
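As a concrete instance of the last item, closeness centrality on an unweighted graph can be computed with a plain BFS. The graph is again a made-up toy example:

```python
from collections import deque

# Closeness centrality: 1 / (sum of shortest-path lengths from the
# node to every other node). Pure-Python sketch on a toy graph.
adj = {0: [1, 2], 1: [0, 2], 2: [0, 1, 3], 3: [2]}

def bfs_distances(src, adj):
    # Standard BFS: shortest hop-distance from src to every node.
    dist = {src: 0}
    queue = deque([src])
    while queue:
        u = queue.popleft()
        for v in adj[u]:
            if v not in dist:
                dist[v] = dist[u] + 1
                queue.append(v)
    return dist

def closeness(node, adj):
    dist = bfs_distances(node, adj)
    return 1 / sum(d for v, d in dist.items() if v != node)

# Node 2 sits in the middle, so its distance sum is smallest
# and its closeness is highest.
scores = {v: closeness(v, adj) for v in adj}
```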
We can also consider the local structure around the node! Some examples:
- Clustering coefficient: it measures how connected my neighbors are. Note: the clustering coefficient counts the triangles in the ego-network. This is how social networks work: if I know two people, sooner or later they will become friends, because I will introduce them at my birthday party.
- Graphlet: a graphlet is a rooted, connected, non-isomorphic subgraph (or simply watch the figures in the course slides). Graphlets matter because with them we can build the Graphlet Degree Vector (GDV), a vector that counts the graphlets a given node participates in.
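The clustering coefficient from the list above is easy to compute by hand: count the connected pairs among a node's neighbors. A minimal sketch on a made-up graph:

```python
# Clustering coefficient of node v: fraction of pairs of v's
# neighbors that are themselves connected (triangles through v).
adj = {0: [1, 2, 3], 1: [0, 2], 2: [0, 1], 3: [0]}

def clustering_coefficient(v, adj):
    neigh = adj[v]
    k = len(neigh)
    if k < 2:
        return 0.0  # fewer than two neighbors: no pairs to check
    links = sum(
        1
        for i in range(k)
        for j in range(i + 1, k)
        if neigh[j] in adj[neigh[i]]
    )
    return 2 * links / (k * (k - 1))

# Node 0 has neighbors {1, 2, 3}; only the pair (1, 2) is connected,
# so 1 triangle out of 3 possible pairs.
print(clustering_coefficient(0, adj))  # 0.3333...
```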
Link features
With these features we can find missing links or predict how links change in the temporal domain.
- Shortest path: the main distance-based feature, with a very simple problem: it ignores whether the two nodes' neighborhoods overlap.
- Local neighborhood overlap: now overlap is captured. We can use the number of common neighbors, the Jaccard coefficient, or the Adamic-Adar index. Another limitation (there is always a limitation): these metrics are always 0 when the two nodes share no neighbors.
- Global neighborhood overlap: we use the whole graph through the Katz index, which is computed from the adjacency matrix.
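The overlap scores above can all be sketched in a few lines. The 4-node graph, the beta value, and the candidate pair (0, 3) are made-up choices for illustration; the Katz index formula S = (I - beta*A)^-1 - I sums walks of every length, discounted by beta:

```python
import math
import numpy as np

# Link-feature sketch on a toy 4-node graph (sets of neighbors).
adj = {0: {1, 2}, 1: {0, 2, 3}, 2: {0, 1}, 3: {1}}

# --- Local neighborhood overlap ---------------------------------
def jaccard(u, v):
    return len(adj[u] & adj[v]) / len(adj[u] | adj[v])

def adamic_adar(u, v):
    # Shared neighbors with many connections contribute less.
    return sum(1 / math.log(len(adj[w])) for w in adj[u] & adj[v])

print(jaccard(0, 3))  # common {1}, union {1, 2} -> 0.5

# --- Global overlap: Katz index ---------------------------------
# S = sum_{l>=1} beta^l * A^l = (I - beta*A)^{-1} - I:
# counts walks of every length between two nodes, discounted by beta.
A = np.array([[0, 1, 1, 0],
              [1, 0, 1, 1],
              [1, 1, 0, 0],
              [0, 1, 0, 0]], dtype=float)
beta = 0.1  # must be < 1/lambda_max(A) for the series to converge
S = np.linalg.inv(np.eye(4) - beta * A) - np.eye(4)
# S[0, 3] is nonzero even though 0 and 3 share no direct edge.
```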
Graph features
These features can be used to characterize the whole graph.
- Kernel method: a graph kernel measures the similarity between two graphs and typically comes with a feature map, so each graph is described by a feature vector, a sort of bag-of-words representation of the graph. The node-degree kernel is a simple example: the feature vector counts how many nodes have each degree.
- Graphlet kernel: we can use graphlets to obtain a richer feature vector, but the time complexity of counting them is too high
- Weisfeiler-Lehman kernel: this kernel solves the previous problem because its time complexity is linear (see reference 2)
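The node-degree kernel mentioned in the list above can be sketched as a "bag of node degrees": count the nodes of each degree, then compare two graphs via the dot product of their count vectors. The two tiny graphs are made-up examples:

```python
from collections import Counter

# "Bag of node degrees" feature vector: how many nodes have
# degree 1, 2, 3, ... Two graphs are then compared by the dot
# product of their vectors - a simple graph kernel.
def degree_feature_vector(adj, max_degree):
    counts = Counter(len(neigh) for neigh in adj.values())
    return [counts.get(d, 0) for d in range(1, max_degree + 1)]

g1 = {0: [1, 2], 1: [0], 2: [0]}        # a star on 3 nodes
g2 = {0: [1, 2], 1: [0, 2], 2: [0, 1]}  # a triangle

phi1 = degree_feature_vector(g1, 3)  # [2, 1, 0]
phi2 = degree_feature_vector(g2, 3)  # [0, 3, 0]
kernel = sum(a * b for a, b in zip(phi1, phi2))  # 3
```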
Choosing the right feature is hard work
What if there were an alternative way? Embeddings are now used in numerous fields, so why not apply them to graphs? We want to map the nodes into an embedding space. In other words, we must encode the nodes so that similarity in the source space is roughly preserved in the embedding space. Now we don't have to choose the right feature but the right similarity metric. Some examples are: direct connection, overlapping neighborhoods, DFS, BFS, random walks with a DFS-BFS tradeoff.
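The random-walk similarity mentioned last can be sketched quickly: nodes that co-occur on short random walks are treated as similar and should land close together in the embedding space (the idea behind methods like DeepWalk and node2vec). The toy graph and walk length are made-up choices:

```python
import random

# Generate random walks over a toy graph. In DeepWalk/node2vec-style
# methods, these walks are fed to a skip-gram model (as in word2vec)
# to learn the actual embedding vectors.
adj = {0: [1, 2], 1: [0, 2], 2: [0, 1, 3], 3: [2]}

def random_walk(start, length, seed=None):
    rng = random.Random(seed)
    walk = [start]
    for _ in range(length):
        # Step to a uniformly random neighbor of the current node.
        walk.append(rng.choice(adj[walk[-1]]))
    return walk

walks = [random_walk(v, 5, seed=i) for i, v in enumerate(adj)]
```

Biasing the neighbor choice toward BFS-like or DFS-like moves gives the DFS-BFS tradeoff mentioned above.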
Graph embeddings
There are different approaches to getting a graph embedding; I will list only three of them:
- Take the mean of the node embeddings
- Add a virtual node connected to the whole graph and get its embedding
- Anonymous walk embedding
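The first approach in the list is one line of NumPy. The random vectors here are placeholders standing in for embeddings learned by some node-level method:

```python
import numpy as np

# Simplest graph embedding: average the node embeddings.
# The vectors below are random placeholders, not learned embeddings.
rng = np.random.default_rng(0)
node_embeddings = {v: rng.standard_normal(8) for v in range(4)}

graph_embedding = np.mean(list(node_embeddings.values()), axis=0)
# The graph embedding lives in the same space as the node embeddings.
```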
References
- Lessons 1.1 to 3.3 of the Stanford course "CS224W: Machine Learning with Graphs" by Prof. Jure Leskovec, 2021
- Shervashidze et al., "Weisfeiler-Lehman Graph Kernels", JMLR, 2011