My graph neural network journey - part 1

Graph Neural Network (GNN) Notes

Disclaimer: This is the first post of a series based on a Stanford YouTube course, which I cite below in the References section. These posts will help me better understand what I am studying and will hopefully give you a high-level description of GNNs and how they work. I will rework and condense the content of that course, but for more detailed notions please watch the original material. If I make some "Grezza" (Roman slang for an error/mistake), let me know through my contacts.

Graphs, graphs, graphs

Introduction

The course is built around graphs, so let's briefly talk about them. Many things can be modeled as a graph (a social network, for example), but what is a graph? Simply a set of nodes connected by edges. I could talk about the plethora of graph types, but I won't (keep it simple).

Many tasks can be defined on graphs, and they can be grouped into 4 levels: node, edge, community (subgraph), and graph. If we want to categorize a user or an item, we talk about node classification, a node-level task. Do you want to know whether someone will buy that item? LINK PREDICTION! An edge-level task. What about predicting the properties of a newly discovered drug molecule? Graph classification (graph level), and so on...

A representation of the task levels of a graph.

For my research interests, the most relevant task is link prediction. In a recommender system it can be modeled on a bipartite graph (users and items), where an edge means that a user bought, rated, or viewed an item. The main task is to predict the missing edges, which are the recommendations.
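To make this concrete, here is a minimal sketch that assumes a toy set of hypothetical user-item interactions. It stores the bipartite graph as two adjacency maps and scores every missing user-item edge by counting paths of length 3 (user, item, other user, candidate item). This path-counting heuristic is my own choice for illustration, not the method taught in the course.

```python
from collections import defaultdict

# Hypothetical interactions: each (user, item) pair is an edge of the bipartite graph
interactions = [
    ("u1", "i1"), ("u1", "i2"),
    ("u2", "i1"), ("u2", "i3"),
    ("u3", "i2"), ("u3", "i3"),
]

def build_bipartite(edges):
    """Return user -> items and item -> users adjacency maps."""
    user_items, item_users = defaultdict(set), defaultdict(set)
    for user, item in edges:
        user_items[user].add(item)
        item_users[item].add(user)
    return user_items, item_users

def recommend(user, user_items, item_users):
    """Score items the user has no edge to by counting length-3 paths."""
    scores = defaultdict(int)
    for item in user_items[user]:
        for other in item_users[item]:
            if other == user:
                continue
            for candidate in user_items[other]:
                if candidate not in user_items[user]:  # only missing edges
                    scores[candidate] += 1
    # Highest score first, ties broken by item name
    return sorted(scores.items(), key=lambda kv: (-kv[1], kv[0]))

user_items, item_users = build_bipartite(interactions)
print(recommend("u1", user_items, item_users))  # [('i3', 2)]
```

Here u1 is linked to i3 through two different users, so the missing edge (u1, i3) is the top recommendation.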

Bipartite graph representation of a recommender system.

How to represent something with a graph?

To represent a domain with a graph we use nodes and edges, as mentioned above. But what if I asked you to represent some phone calls or a social network? Would the resulting graphs be the same? Nope. A friendship is a mutual connection: if I am your friend, you are my friend, so the edge is undirected. A phone call goes from the caller to the receiver, so the edge is directed. We have to pay attention to what we are modeling: sometimes a problem can be modeled in just one way, but some complex domains can be modeled in several ways.
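A minimal sketch of the difference, assuming a simple adjacency-set representation and two hypothetical people: the same pair produces different graphs depending on whether the relation is mutual (friendship, undirected) or one-way (phone call, directed).

```python
from collections import defaultdict

def build_adjacency(edges, directed):
    """Adjacency sets; undirected edges are stored in both directions."""
    adj = defaultdict(set)
    for u, v in edges:
        adj[u].add(v)
        if not directed:
            adj[v].add(u)  # mutual relation: v is also connected to u
    return adj

# Hypothetical data: the same pair of people in two different domains
friendships = [("alice", "bob")]   # alice and bob are friends
phone_calls = [("alice", "bob")]   # alice called bob

social = build_adjacency(friendships, directed=False)
calls = build_adjacency(phone_calls, directed=True)

print("alice" in social["bob"])  # True: friendship is mutual
print("alice" in calls["bob"])   # False: bob never called alice
```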

Some graph theory that you should know (you can watch the reference videos to gain more confidence):

Graph feature design

Good features = good performance

A good set of features is the key to obtaining a model with good performance. For example, suppose we want to make predictions for a set of nodes: given the graph G = (V, E), we want to learn a function f that maps each node to a real value.

Input: G = (V, E)

Goal: learn f : V -> R

Task: node classification (predict the colors!)

A simple graph with 5 nodes and 6 edges.

Looking at the image above, we can assume that a node n is green when its degree is greater than 1. By choosing the right feature (the node degree), we can predict the color.
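A small sketch of the degree rule, assuming a hypothetical graph (the edges of the figure are not available here). Instead of a learned f, it uses the hand-written rule from the text: degree > 1 means green.

```python
from collections import Counter

# Hypothetical undirected graph (not the exact one from the figure)
edges = [("a", "b"), ("b", "c"), ("b", "d"), ("c", "d"), ("d", "e")]

def node_degrees(edges):
    """Degree of each node: number of incident edges."""
    deg = Counter()
    for u, v in edges:
        deg[u] += 1
        deg[v] += 1
    return deg

def predict_colors(edges, threshold=1):
    """Hand-written rule based only on the degree feature."""
    return {node: ("green" if d > threshold else "not green")
            for node, d in node_degrees(edges).items()}

print(predict_colors(edges))
# {'a': 'not green', 'b': 'green', 'c': 'green', 'd': 'green', 'e': 'not green'}
```

In a real pipeline the threshold would not be hand-picked: the degree (and other features) would be fed to a classifier that learns the mapping.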

Now let's dive into the three main types of features: node features, link features, and graph features.

Node features

The simplest node feature is the degree, but it has a problem: it counts the neighbors of a node without capturing how important they are. A node with many connections that sits in a central position of a spring-embedding layout of the graph is more important than the outer nodes. So we can also consider node centrality, which comes in several flavors; some of them are eigenvector centrality, betweenness centrality, and closeness centrality.
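As an example of a centrality measure, here is a sketch of closeness centrality, defined as 1 divided by the sum of the shortest-path lengths from a node to all other nodes. It assumes an unweighted, connected graph (a hypothetical path a-b-c-d-e) and computes distances with breadth-first search.

```python
from collections import deque

def bfs_distances(adj, source):
    """Shortest-path lengths (in hops) from source to every reachable node."""
    dist = {source: 0}
    queue = deque([source])
    while queue:
        node = queue.popleft()
        for nb in adj[node]:
            if nb not in dist:
                dist[nb] = dist[node] + 1
                queue.append(nb)
    return dist

def closeness_centrality(adj, v):
    """1 / sum of shortest-path lengths from v; assumes a connected graph."""
    total = sum(bfs_distances(adj, v).values())
    return 1.0 / total if total > 0 else 0.0

# Hypothetical path graph a - b - c - d - e
adj = {"a": ["b"], "b": ["a", "c"], "c": ["b", "d"], "d": ["c", "e"], "e": ["d"]}
for node in adj:
    print(node, round(closeness_centrality(adj, node), 3))
```

The middle node c has the smallest total distance to everyone else, so it gets the highest score, while the outer nodes a and e get the lowest, matching the intuition in the paragraph above.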

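We can also consider the local structure around the node! Some examples are the clustering coefficient and graphlet-based features such as the graphlet degree vector (GDV).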
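The clustering coefficient is a simple local-structure feature: the fraction of pairs of a node's neighbors that are connected to each other. A minimal sketch on a hypothetical graph (a triangle a-b-c plus a node d attached to a); returning 0 for nodes with fewer than two neighbors is a convention I chose, since the value is undefined there.

```python
from itertools import combinations

def clustering_coefficient(adj, v):
    """Fraction of pairs of v's neighbors that are themselves connected."""
    neighbors = adj[v]
    k = len(neighbors)
    if k < 2:
        return 0.0  # undefined for degree < 2; 0 by convention here
    links = sum(1 for a, b in combinations(neighbors, 2) if b in adj[a])
    return links / (k * (k - 1) / 2)

# Hypothetical graph: triangle a-b-c plus a pendant node d attached to a
adj = {"a": {"b", "c", "d"}, "b": {"a", "c"}, "c": {"a", "b"}, "d": {"a"}}
print(round(clustering_coefficient(adj, "a"), 3))  # 0.333
print(clustering_coefficient(adj, "b"))            # 1.0
print(clustering_coefficient(adj, "d"))            # 0.0
```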

Link features

Link features describe pairs of nodes. With these features we can find missing links or predict how links change over time.
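Three standard link features for a candidate pair (u, v), picked here as examples, are the number of common neighbors, the Jaccard coefficient, and the Adamic-Adar index. The sketch below computes them on a hypothetical graph where a and d are not yet connected.

```python
import math

def common_neighbors(adj, u, v):
    """Number of neighbors shared by u and v."""
    return len(adj[u] & adj[v])

def jaccard_coefficient(adj, u, v):
    """Shared neighbors divided by the size of the combined neighborhood."""
    union = adj[u] | adj[v]
    return len(adj[u] & adj[v]) / len(union) if union else 0.0

def adamic_adar(adj, u, v):
    """Sum of 1 / log(degree) over shared neighbors: low-degree ones count more."""
    return sum(1.0 / math.log(len(adj[w])) for w in adj[u] & adj[v])

# Hypothetical graph; a and d are not connected yet
adj = {
    "a": {"b", "c"},
    "b": {"a", "d"},
    "c": {"a", "d"},
    "d": {"b", "c", "e"},
    "e": {"d"},
}
print(common_neighbors(adj, "a", "d"))               # 2
print(round(jaccard_coefficient(adj, "a", "d"), 3))  # 0.667
print(round(adamic_adar(adj, "a", "d"), 3))          # 2.885
```

A high score suggests that the missing link (a, d) is a good candidate for prediction.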

Graph features

These features characterize the whole graph.
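One well-known graph-level feature comes from the Weisfeiler-Lehman (WL) graph kernel (reference 2). Here is a sketch of the idea, under simplifying assumptions: nodes start unlabeled (all with the same color), each round a node's new color is its old color plus the sorted colors of its neighbors, and the graph is described by the histogram of all colors seen. Colors are kept as readable strings instead of being compressed into short new labels, and the kernel is the dot product of two histograms.

```python
from collections import Counter

def wl_histogram(adj, iterations=2):
    """Bag of WL colors collected over all refinement rounds."""
    colors = {v: "1" for v in adj}  # every node starts with the same color
    histogram = Counter(colors.values())
    for _ in range(iterations):
        # New color = old color + sorted multiset of neighbor colors
        colors = {
            v: colors[v] + "(" + ",".join(sorted(colors[u] for u in adj[v])) + ")"
            for v in adj
        }
        histogram.update(colors.values())
    return histogram

def wl_kernel(adj1, adj2, iterations=2):
    """Dot product of the two color histograms."""
    h1 = wl_histogram(adj1, iterations)
    h2 = wl_histogram(adj2, iterations)
    return sum(count * h2[color] for color, count in h1.items())

path = {"a": ["b"], "b": ["a", "c"], "c": ["b"]}
triangle = {"x": ["y", "z"], "y": ["x", "z"], "z": ["x", "y"]}
print(wl_kernel(path, path), wl_kernel(path, triangle), wl_kernel(triangle, triangle))
# 19 12 27
```

The path and the triangle share only the initial color and the "middle node" color of the path, so their cross-similarity (12) is lower than either self-similarity.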

Choosing the right features is hard work

What if there were an alternative? Embeddings are now used in many fields, so why not apply them to graphs? We want to map the nodes into an embedding space. In other words, we encode the nodes so that their similarity in the original graph is approximately preserved in the embedding space (for example, as the dot product of their embeddings). Now we don't have to choose the right features but the right similarity metric. Some examples are: being connected by an edge, having overlapping neighborhoods, DFS, BFS, and random walks with a DFS-BFS tradeoff.
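The last similarity notion can be sketched as a second-order random walk in the style of node2vec: a return parameter p and an in-out parameter q bias each step toward going back, staying close to the previous node (BFS-like), or moving away (DFS-like). Nodes that co-occur often in such walks are treated as similar; an embedding method (for example skip-gram training) would then be fit to preserve those co-occurrences, which is not shown here. The graph, seed, and parameter values are hypothetical.

```python
import random
from collections import Counter

def biased_walk(adj, start, length, p=1.0, q=1.0, rng=None):
    """Second-order random walk: small q favors moving away (DFS-like),
    large q keeps the walk close to where it came from (BFS-like)."""
    rng = rng or random.Random(0)
    walk = [start]
    while len(walk) < length:
        cur = walk[-1]
        neighbors = sorted(adj[cur])
        if not neighbors:
            break  # dead end
        if len(walk) == 1:
            walk.append(rng.choice(neighbors))  # first step is uniform
            continue
        prev = walk[-2]
        weights = []
        for x in neighbors:
            if x == prev:
                weights.append(1.0 / p)  # go back to the previous node
            elif x in adj[prev]:
                weights.append(1.0)      # stay at distance 1 from prev
            else:
                weights.append(1.0 / q)  # move further away from prev
        walk.append(rng.choices(neighbors, weights=weights, k=1)[0])
    return walk

def cooccurrence(walks, window=2):
    """Count how often two nodes appear within `window` steps of each other."""
    counts = Counter()
    for walk in walks:
        for i, u in enumerate(walk):
            for v in walk[i + 1 : i + 1 + window]:
                if u != v:
                    counts[tuple(sorted((u, v)))] += 1
    return counts

adj = {"a": {"b", "c"}, "b": {"a", "c"}, "c": {"a", "b", "d"}, "d": {"c", "e"}, "e": {"d"}}
rng = random.Random(42)
walks = [biased_walk(adj, node, length=6, p=1.0, q=0.5, rng=rng)
         for node in sorted(adj) for _ in range(10)]
print(cooccurrence(walks).most_common(3))
```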

Graph embeddings

There are different approaches to obtain an embedding of a whole graph; I will list only three of them:

  1. Average the node embeddings
  2. Add a virtual node connected to every node of the graph and use its embedding
  3. Anonymous walk embeddings
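Approaches 1 and 3 can be sketched in a few lines, assuming node embeddings are already available as plain lists of floats (the values below are made up). Approach 2 needs a trained model, so it is not shown. In an anonymous walk, each node is replaced by the index of its first appearance in the walk, so only the structure of the walk is kept; the graph is then represented by the distribution of anonymous walks of a fixed length.

```python
import random
from collections import Counter

def mean_graph_embedding(node_embeddings):
    """Approach 1: average the node embeddings dimension by dimension."""
    vectors = list(node_embeddings.values())
    dim = len(vectors[0])
    return [sum(vec[d] for vec in vectors) / len(vectors) for d in range(dim)]

def anonymize(walk):
    """Replace each node with the index of its first occurrence in the walk."""
    first_seen = {}
    for node in walk:
        first_seen.setdefault(node, len(first_seen) + 1)
    return tuple(first_seen[node] for node in walk)

def anonymous_walk_embedding(adj, walk_length, num_walks, seed=0):
    """Approach 3: empirical distribution of anonymous walks of `walk_length` nodes.
    Assumes every node has at least one neighbor."""
    rng = random.Random(seed)
    nodes = sorted(adj)
    counts = Counter()
    for _ in range(num_walks):
        walk = [rng.choice(nodes)]
        while len(walk) < walk_length:
            walk.append(rng.choice(sorted(adj[walk[-1]])))
        counts[anonymize(walk)] += 1
    return {aw: c / num_walks for aw, c in counts.items()}

node_embeddings = {"a": [1.0, 2.0], "b": [3.0, 4.0]}  # made-up values
print(mean_graph_embedding(node_embeddings))  # [2.0, 3.0]
print(anonymize(["a", "b", "c", "b", "c"]))   # (1, 2, 3, 2, 3)

triangle = {"x": ["y", "z"], "y": ["x", "z"], "z": ["x", "y"]}
print(anonymous_walk_embedding(triangle, walk_length=3, num_walks=1000))
```

On a triangle, walks of 3 nodes can only produce the anonymous walks (1, 2, 1) and (1, 2, 3), so the embedding is a distribution over those two patterns.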

References

  1. Lessons 1.1 to 3.3 of the Stanford course "CS224W: Machine Learning with Graphs" by Prof. Jure Leskovec, 2021
  2. N. Shervashidze et al., "Weisfeiler-Lehman Graph Kernels", JMLR, 2011