Note: The complete code for this post on machine learning with graphs can be found here. 


When we think about data science and machine learning, we usually think about tabular data. We have data organized in rows and columns, which could take the form of SQL tables, CSV or JSON files, Pandas dataframes, and the like. We think of these data points as being independent of each other. After all, how do you create a model where one row might depend on another?

Once we have our data and have spent enough time on exploratory data analysis (EDA) and data cleaning, we think about creating a model of some sort. Perhaps we are doing binary classification and predicting a label. Or maybe we are doing regression and predicting a value. Regardless of the end goal, we need some sort of vector or embedding that turns our data into a numerical input for the model. We might have a column of “True / False” values that need to be encoded as 1s and 0s. Or maybe we have text that needs to be vectorized through natural language processing (NLP) techniques. Numerical values often need to be normalized in some fashion. The end result is a vector of numbers that tools like scikit-learn, TensorFlow, or PyTorch can train on. We can then apply the trained model to unseen data and make predictions.
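As a minimal sketch of that standard workflow, assuming hypothetical user records with a boolean `verified` column and a numeric `posts` column, the encoding and normalization steps might look like this in plain Python (scikit-learn offers fuller tools for the same job):

```python
def encode_rows(rows):
    """Turn a list of dict records into numeric feature vectors.

    Booleans become 1/0 and the numeric column is min-max scaled to [0, 1].
    Apart from the column-wide min/max, each row is encoded independently:
    nothing in a row's vector says anything about its relationship to other rows.
    """
    posts = [r["posts"] for r in rows]
    lo, hi = min(posts), max(posts)
    span = (hi - lo) or 1  # avoid division by zero when all values are equal
    vectors = []
    for r in rows:
        verified = 1 if r["verified"] else 0
        scaled_posts = (r["posts"] - lo) / span
        vectors.append([verified, scaled_posts])
    return vectors


# Hypothetical example records
users = [
    {"verified": True, "posts": 120},
    {"verified": False, "posts": 20},
    {"verified": False, "posts": 70},
]
print(encode_rows(users))  # [[1, 1.0], [0, 0.0], [0, 0.5]]
```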

This is the standard workflow, and it rests on the assumption that each vector is unrelated to the others. But what if that is not the case?

There are plenty of real-world situations where individual data points are actually linked to each other. Consider, for example, the common data science task of predicting customer churn. For the sake of this example, let’s assume we are looking at the users of a social network. We could collect information on each user, such as when they joined, how many posts they have made, whether their account is verified, how many friends or followers they have, and when they last posted. We could then use all of that information, along with a labeled dataset, to create a churn model. However, this model neglects one key concept: if my friends on a social network churn, I am much more likely to churn myself.
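One way to see what the tabular model is missing is to compute a relational feature by hand. The sketch below assumes a hypothetical friendship adjacency list and a set of users already known to have churned, and computes for each user the fraction of their friends who churned. No row-by-row encoding of a single user's own attributes can produce this value.

```python
def friend_churn_fraction(friends, churned):
    """Fraction of each user's friends who have churned (0.0 if no friends)."""
    features = {}
    for user, neighbours in friends.items():
        if neighbours:
            churned_friends = sum(1 for f in neighbours if f in churned)
            features[user] = churned_friends / len(neighbours)
        else:
            features[user] = 0.0
    return features


# Hypothetical social network (undirected friendships listed in both directions)
friends = {
    "ana": ["ben", "cai"],
    "ben": ["ana", "cai", "dee"],
    "cai": ["ana", "ben"],
    "dee": ["ben"],
    "eve": [],
}
churned = {"cai", "dee"}
print(friend_churn_fraction(friends, churned))
```

Here `ben`, two of whose three friends churned, gets 2/3, while `eve`, who has no friends in the data, gets 0.0.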

This brings us to network graphs: looking at which data points are connected to which others through one or a series of relationships.  But how do we actually create models based on this type of data?  This will be the subject of my ODSC West 2021 workshop this November, “Machine Learning With Graphs: Going Beyond Tabular Data.”

Working with graph data

There are many packages, in many programming languages, for working with graphs. For this post, we will use Python to connect to a free Neo4j Sandbox database instance. As a demonstration, we will use a subset of the arXiv dataset on Kaggle (the full dataset is 1.1 TB and growing), which represents a corpus of research papers. Using Google Colab and a Kaggle account, we can download this dataset and load a sample of it as follows:

!pip install -q kaggle

import json

import pandas as pd
from google.colab import files
from tqdm import tqdm

# upload your kaggle.json API key here:
files.upload()

!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json
!kaggle datasets download -d Cornell-University/arxiv
!unzip *.zip && rm *.zip

file = "./arxiv-metadata-oai-snapshot.json"

metadata = []

lines = 100000    # 100k for testing
with open(file, 'r') as f:
    for line in tqdm(f):
        # each line of the snapshot is one JSON record describing a paper
        metadata.append(json.loads(line))
        lines -= 1
        if lines == 0:
            break

df = pd.DataFrame(metadata)

The raw records contain many fields. We keep a subset of these columns: the ID, submitter name, journal the paper was published in, abstract, authors list, and categories. After some cleaning, we obtain the following Pandas dataframe:

[Figure: the cleaned arXiv data as a Pandas dataframe]
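The cleaning steps themselves are not spelled out here. As a hedged sketch, assuming the field names used in the Kaggle arXiv snapshot (`id`, `submitter`, `journal-ref`, `abstract`, `authors`, `categories`) and that `categories` is a single space-separated string, the column selection and a multi-hot encoding of the categories (a natural target format for multi-label classification) could be done on the raw records before building the dataframe. The helper names `clean_record` and `multi_hot` are hypothetical.

```python
KEEP = ["id", "submitter", "journal-ref", "abstract", "authors", "categories"]


def clean_record(record):
    """Keep a subset of fields and split the category string into a list."""
    row = {key: record.get(key) for key in KEEP}  # missing fields become None
    # Categories arrive as one space-separated string, e.g. "cs.LG stat.ML"
    row["categories"] = (row["categories"] or "").split()
    return row


def multi_hot(rows):
    """Encode each row's categories as a 0/1 vector over all categories seen."""
    vocab = sorted({c for row in rows for c in row["categories"]})
    index = {c: i for i, c in enumerate(vocab)}
    vectors = []
    for row in rows:
        vec = [0] * len(vocab)
        for c in row["categories"]:
            vec[index[c]] = 1
        vectors.append(vec)
    return vocab, vectors


# Hypothetical records shaped like lines of the snapshot file
raw = [
    {"id": "0001", "submitter": "A. Author", "authors": "A. Author, B. Author",
     "abstract": "Abstract text.", "categories": "cs.LG stat.ML", "title": "Paper one"},
    {"id": "0002", "submitter": "C. Author", "authors": "C. Author",
     "abstract": "Abstract text.", "categories": "cs.LG", "title": "Paper two"},
]
rows = [clean_record(r) for r in raw]
vocab, labels = multi_hot(rows)
print(vocab)   # ['cs.LG', 'stat.ML']
print(labels)  # [[1, 1], [1, 0]]
```

With the loading code above, `clean_record` could be applied to each parsed line before the records are passed to `pd.DataFrame`.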

Suppose we wish to perform multi-class, multi-label classification to predict the tags associated with a given paper. In traditional data science, it might make sense to take text like the title and abstract and use NLP to generate word embeddings. However, there are clearly relationships present in this dataset. First, there is a connection between the authors of the same paper. Second, there are connections between the papers themselves and the categories they are published in. We could capture these relationships in a graph schema in which author nodes connect to the papers they wrote, and paper nodes connect to the categories they are listed under.

In such a schema, we could anticipate that there could be multiple author nodes connected to a paper node, and a paper node could be connected to multiple category nodes.  Additionally, an author could write multiple papers in the data set, which could have differing categories and, naturally, there will be many papers associated with each category.  We could also connect two authors to indicate that they have co-authored a paper.  (Note that this was not done in the sample code provided above.)  If we consider a small subset of just a single category, we would find a graph that looks like this:

[Figure: the authors and papers connected to a single category]

where the dark blue nodes represent the authors, the light blue nodes represent the papers, and the green node at the center is a single tag.
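To make the schema described above concrete without a database, here is a sketch that turns cleaned records into node and relationship lists. The relationship names `AUTHORED`, `IN_CATEGORY`, and `CO_AUTHORED` are hypothetical labels chosen for illustration, and the sketch assumes each record already holds a list of author names (the raw `authors` field is a single string that would need parsing first).

```python
from itertools import combinations


def build_graph(records):
    """Return (nodes, edges) for an Author-Paper-Category graph.

    nodes maps a label ("Author", "Paper", "Category") to a set of keys;
    edges is a set of (source, relationship, target) triples.
    """
    nodes = {"Author": set(), "Paper": set(), "Category": set()}
    edges = set()
    for rec in records:
        paper = rec["id"]
        nodes["Paper"].add(paper)
        for author in rec["authors"]:
            nodes["Author"].add(author)
            edges.add((author, "AUTHORED", paper))
        for category in rec["categories"]:
            nodes["Category"].add(category)
            edges.add((paper, "IN_CATEGORY", category))
        # Optional author-author links for every pair on the same paper;
        # sorting the pair keeps a single undirected edge per pair.
        for a, b in combinations(sorted(set(rec["authors"])), 2):
            edges.add((a, "CO_AUTHORED", b))
    return nodes, edges


# Hypothetical cleaned records
records = [
    {"id": "p1", "authors": ["Lee", "Kim"], "categories": ["cs.LG", "stat.ML"]},
    {"id": "p2", "authors": ["Kim"], "categories": ["cs.LG"]},
]
nodes, edges = build_graph(records)
print(sorted(nodes["Author"]))  # ['Kim', 'Lee']
print(sorted(e for e in edges if e[1] == "CO_AUTHORED"))  # [('Kim', 'CO_AUTHORED', 'Lee')]
```

`Kim` is linked to two papers and `cs.LG` to both of them, which is the many-to-many shape the schema calls for. In practice these lists would be loaded into Neo4j rather than kept in memory.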

While images like the one above are useful for EDA, we need some way to convert at least the nodes in this graph into embeddings for use in modeling. Unlike with tabular data, however, the embedding for each node should reflect the nodes it is connected to. There are a variety of ways to create such embeddings. A popular approach is to conduct random walks around the graph, recording which nodes are encountered along the way. node2vec, for example, feeds such (biased) random walks into a skip-gram model, similar to how the popular word2vec algorithm works in NLP. Another popular method, FastRP (Fast Random Projection), does not sample walks at all: it builds embeddings by repeatedly propagating sparse random projections across the graph's relationships, capturing similar multi-step neighborhood information very quickly. Using these algorithms, it is possible to create an embedding for each node. Beyond these two, there is a whole host of other node embedding methods, including more complex ones (beyond the scope of this post) that can also use node properties when creating embeddings. Some techniques even go so far as to create embeddings for whole subgraphs or even entire graphs!
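A minimal sketch of the walk-based idea in plain Python: it generates uniform random walks from each node (as in DeepWalk; node2vec would additionally bias each step with its return and in-out parameters, omitted here) and turns each walk into the (center, context) pairs a skip-gram model would be trained on. The toy graph, walk length, and window size are hypothetical choices, and training the skip-gram model itself is not shown.

```python
import random


def random_walks(adj, walks_per_node=2, walk_length=5, seed=42):
    """Uniform random walks; each walk is a list of node ids."""
    rng = random.Random(seed)
    walks = []
    for _ in range(walks_per_node):
        for start in adj:
            walk = [start]
            # Stop early only if the current node has no neighbours
            while len(walk) < walk_length and adj[walk[-1]]:
                walk.append(rng.choice(adj[walk[-1]]))
            walks.append(walk)
    return walks


def skipgram_pairs(walks, window=2):
    """(center, context) pairs: nodes within `window` steps of each other."""
    pairs = []
    for walk in walks:
        for i, center in enumerate(walk):
            lo, hi = max(0, i - window), min(len(walk), i + window + 1)
            for j in range(lo, hi):
                if j != i:
                    pairs.append((center, walk[j]))
    return pairs


# Toy undirected graph: two authors sharing a paper, two papers in one category
adj = {
    "author_1": ["paper_a"],
    "author_2": ["paper_a", "paper_b"],
    "paper_a": ["author_1", "author_2", "cs.LG"],
    "paper_b": ["author_2", "cs.LG"],
    "cs.LG": ["paper_a", "paper_b"],
}
walks = random_walks(adj)
pairs = skipgram_pairs(walks)
print(walks[0])                # a 5-node walk starting at 'author_1'
print(len(walks), len(pairs))  # 10 140
```

Nodes that often appear near each other in walks produce many shared pairs, so a skip-gram model trained on them places those nodes close together in embedding space.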

These embeddings can then be used as inputs to a data science or machine learning model, just like any other embedding from the traditional, tabular approach. The difference is that each row's embedding also captures information about how that data point is related to the others.
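Once each node has an embedding, using it is ordinary feature engineering. The sketch below assumes hypothetical, already-computed 2-dimensional node embeddings (in practice they would come from FastRP, node2vec, or a similar method) and hypothetical churn labels. It appends each embedding to that row's tabular features and fits a deliberately simple nearest-centroid classifier so the example stays dependency-free; a scikit-learn, TensorFlow, or PyTorch model could take the same feature rows.

```python
import math


def combine(tabular, embeddings):
    """Append each node's embedding to its tabular feature vector."""
    return {node: tabular[node] + embeddings[node] for node in tabular}


def fit_centroids(features, labels):
    """Mean feature vector per class, using only the labeled nodes."""
    sums, counts = {}, {}
    for node, label in labels.items():
        vec = features[node]
        acc = sums.setdefault(label, [0.0] * len(vec))
        for k, v in enumerate(vec):
            acc[k] += v
        counts[label] = counts.get(label, 0) + 1
    return {label: [v / counts[label] for v in acc] for label, acc in sums.items()}


def predict(centroids, vec):
    """Label of the closest centroid (Euclidean distance)."""
    return min(centroids, key=lambda label: math.dist(vec, centroids[label]))


# Hypothetical tabular features, node embeddings, and labels
tabular = {"u1": [1.0], "u2": [0.0], "u3": [1.0], "u4": [0.0]}
embeddings = {"u1": [0.9, 0.1], "u2": [0.1, 0.8], "u3": [0.8, 0.2], "u4": [0.2, 0.9]}
labels = {"u1": "stays", "u2": "churns", "u3": "stays"}

features = combine(tabular, embeddings)
centroids = fit_centroids(features, labels)
print(features["u4"])                      # [0.0, 0.2, 0.9]
print(predict(centroids, features["u4"]))  # churns
```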

Wrapping up

In this post, I discussed why it is worth considering the relationships between data points rather than looking only at tabular data. In my upcoming ODSC West 2021 talk, I will go into more detail on how to tell when a problem lends itself well to this approach. I will also walk you through generating these embeddings and building some basic machine learning models with them. I hope to see you there!

About the author / ODSC West 2021 speaker on machine learning with graphs:

Dr. Clair Sullivan is currently a graph data science advocate at Neo4j, working to expand the community of data scientists and machine learning engineers using graphs to solve challenging problems.  She received her doctorate degree in nuclear engineering from the University of Michigan in 2002.  Prior to starting in tech, she worked in a variety of settings including national laboratories, the federal government, and as a professor at the University of Illinois.  She has authored 4 book chapters, over 20 peer-reviewed papers, and more than 30 conference papers.  Dr. Sullivan was the recipient of the DARPA Young Faculty Award in 2014 and the American Nuclear Society’s Mary J. Oestmann Professional Women’s Achievement Award in 2015.


Twitter: @CJLovesData1