GNN Hands On 01
Hands-on Graph Neural Networks with PyTorch Geometric (1): Cora Dataset
How to handle PyTorch Geometric and NetworkX
Characteristics of the Cora dataset
How to effectively visualize data
import os
import collections
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.sparse as sp
import torch
from torch import Tensor
import torch_geometric
from torch_geometric.utils import to_networkx
from torch_geometric.datasets import Planetoid
import networkx as nx
from networkx.algorithms import community
cuda
Cora Dataset
The Cora dataset is a well-known dataset in the field of graph research. It consists of 2708 scientific publications, each classified into one of seven
classes. The citation network consists of 5429 links. Each publication in the dataset is described by a 0/1-valued word vector indicating the
absence/presence of the corresponding word from the dictionary. The dictionary consists of 1433 unique words.
As a side note, there is a service that displays a network of papers connected by citation relations. This is very useful when looking for related
studies. See here for details.
First, download the dataset by running the command below. In this article we will work with the data using PyTorch Geometric and NetworkX.
https://colab.research.google.com/drive/1D7bRYKu44NCA-XPIQElQWdtmlvcjrlwo#printMode=true 1/11
11/26/23, 11:36 PM GNN HANDS 001 - Colaboratory
# data_dir must point to a local download directory; its definition is not shown here.
dataset = Planetoid(root=data_dir, name='Cora')
data = dataset[0]
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.test.index
Processing...
Done!
Nodes
The Cora dataset contains 2708 papers, which are represented as nodes in the graph.
Edges
The papers in the Cora dataset have 5429 citation connections, which are represented as edges in the graph. The edge information is unique to
graph data.
The number of edges reported for the dataset is 10556, though. Let’s find out why this is roughly twice 5429 (10556 = 2 × 5278, so not exactly double).
The first line of code confirms that there are no nodes not connected by edges, the second line of code shows that there are no self-loops, and
the third line of code shows that edges are not directional. This means that the edge count is double the actual count because of the bi-
directional edge information included.
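The cell that runs these three checks is not shown here. As a rough sketch of what such checks do, the NumPy version below tests a toy edge list (made-up values) for isolated nodes, self-loops and symmetry. The function names are chosen to mirror the `has_isolated_nodes()`, `has_self_loops()` and `is_undirected()` methods of a PyTorch Geometric `Data` object, but that the original cell used those methods is an assumption.

```python
import numpy as np

# Toy undirected graph with 4 nodes, stored PyG-style: each undirected
# edge appears twice, once per direction. These values are made up.
edge_index = np.array([[0, 1, 1, 2, 2, 3],
                       [1, 0, 2, 1, 3, 2]])
num_nodes = 4

def has_isolated_nodes(edge_index, num_nodes):
    # A node is isolated if it never appears as a source or a target.
    # Assumes node ids run from 0 to num_nodes - 1.
    return np.unique(edge_index).size < num_nodes

def has_self_loops(edge_index):
    # A self-loop is a column whose source equals its target.
    return bool(np.any(edge_index[0] == edge_index[1]))

def is_undirected(edge_index):
    # Undirected storage means the set of (src, dst) pairs equals
    # the set of reversed (dst, src) pairs.
    forward = set(zip(edge_index[0].tolist(), edge_index[1].tolist()))
    backward = set(zip(edge_index[1].tolist(), edge_index[0].tolist()))
    return forward == backward

print(has_isolated_nodes(edge_index, num_nodes))  # False
print(has_self_loops(edge_index))                 # False
print(is_undirected(edge_index))                  # True
```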
Let’s see how edge information is stored. As an example, let's look at the edges of the node with index 30.
edge_index = data.edge_index.numpy()
print(edge_index.shape)
edge_example = edge_index[:, np.where(edge_index[0]==30)[0]]
edge_example
(2, 10556)
array([[ 30, 30, 30, 30, 30, 30],
[ 697, 738, 1358, 1416, 2162, 2343]])
Each column of the array is one edge, stored as a pair of node indices: the first row holds the source node and the second row holds the target node.
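To make the storage layout concrete, here is a minimal sketch on a toy edge list (made-up values, not Cora) that turns the columns into explicit pairs and then collapses the two directions of each edge into one.

```python
import numpy as np

# Toy edge list in the same (2, num_entries) layout; values are made up.
edge_index = np.array([[0, 1, 1, 2, 2, 3],
                       [1, 0, 2, 1, 3, 2]])

# Each column is one (source, target) pair.
pairs = list(zip(edge_index[0].tolist(), edge_index[1].tolist()))
print(pairs)  # [(0, 1), (1, 0), (1, 2), (2, 1), (2, 3), (3, 2)]

# Sort the two endpoints within each column, then drop duplicate columns
# so that (0, 1) and (1, 0) count as one undirected edge.
undirected = np.unique(np.sort(edge_index, axis=0), axis=1)
print(undirected.shape[1])  # 3
```

If every edge is stored exactly once per direction and there are no self-loops, the same procedure applied to Cora's 10556 entries would be expected to give 10556 / 2 = 5278 undirected edges.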
node_example = np.unique(edge_example.flatten())
plt.figure(figsize=(10, 6))
G = nx.Graph()
G.add_nodes_from(node_example)
G.add_edges_from(list(zip(edge_example[0], edge_example[1])))
nx.draw_networkx(G, with_labels=False)
plt.axis('off')
plt.show()
Node Degree
Degree in graph theory means the number of edges joining a vertex (node) in a graph. We saw earlier that each node always has an edge, so
how many edges does each node have on average?
We found that the average node degree is 3.9. You may have thought it was surprisingly low. We can check the overall distribution by drawing a
histogram of the degree.
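The 3.9 figure can be checked by hand from the counts quoted earlier. The sketch below assumes only those two numbers; the toy edge list in the second half is made up for illustration.

```python
import numpy as np

num_nodes = 2708     # papers, from the text above
num_entries = 10556  # columns of edge_index, both directions included

# Every undirected edge contributes 1 to the degree of both endpoints,
# so the degree sum equals the number of directed entries.
print(round(num_entries / num_nodes, 2))  # 3.9

# The same idea on a toy edge list (made-up values): count how often
# each node appears as a source.
toy_edge_index = np.array([[0, 1, 1, 2, 2, 3],
                           [1, 0, 2, 1, 3, 2]])
toy_degrees = np.bincount(toy_edge_index[0], minlength=4)
print(toy_degrees.tolist(), toy_degrees.mean())  # [1, 2, 2, 1] 1.5
```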
G = to_networkx(data, to_undirected=True)
degrees = [val for (node, val) in G.degree()]
display(pd.DataFrame(pd.Series(degrees).describe()).transpose().round(2))
print(len(degrees))
print(sum(degrees))
plt.figure(figsize=(10, 6))
plt.hist(degrees, bins=50)
plt.xlabel("node degree")
plt.show()
A high degree means that a node is connected to many other nodes (papers), so nodes with a high degree are likely to be important.
The intuition is familiar from literature searches: the number of citations is a common, though imperfect, signal of a paper's influence.
Let's plot the graph to see where the top 10 nodes with the highest degree are located.
G = to_networkx(data, to_undirected=True)
pos = nx.spring_layout(G, seed=42)
cent = nx.degree_centrality(G)
node_size = list(map(lambda x: x * 800, cent.values()))
cent_array = np.array(list(cent.values()))
# Note: index 10 of the descending sort is the 11th-largest value, so the
# comparison below keeps at least 11 nodes; index 9 would give the top 10
# (plus any ties).
threshold = sorted(cent_array, reverse=True)[10]
print("threshold", threshold)
cent_bin = np.where(cent_array >= threshold, 1, 0.1)
threshold 0.011821204285186553
plt.figure(figsize=(22, 16))
nodes = nx.draw_networkx_nodes(G, pos, node_size=node_size,
                               cmap=plt.cm.plasma,
                               node_color=cent_bin,
                               nodelist=list(cent.keys()),
                               alpha=cent_bin)
edges = nx.draw_networkx_edges(G, pos, width=0.50, alpha=0.3)
plt.axis('off')
plt.show()
The highest-degree nodes (those at or above the threshold) are drawn as yellow dots, and all other nodes as gray dots. The size of each dot is
proportional to its degree. You can see that the yellow dots are all located in the central part of the network.
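Picking the k highest-degree nodes is easy to get off by one. The following sketch uses made-up degrees and shows two equivalent ways to select them with NumPy; the variable names are illustrative only.

```python
import numpy as np

# Made-up degrees for 6 nodes; node ids are the array positions.
degrees = np.array([3, 9, 1, 7, 7, 2])
k = 3

# argsort is ascending, so sort the negated values and keep the first k.
# kind="stable" makes the order of ties reproducible.
top_k = np.argsort(-degrees, kind="stable")[:k]
print(top_k.tolist())  # [1, 3, 4]

# Equivalent threshold form: the k-th largest value sits at index k - 1
# of the descending sort (index k would be the (k + 1)-th largest).
threshold = np.sort(degrees)[::-1][k - 1]
print(threshold, (degrees >= threshold).sum())  # 7 3
```

The threshold form can select more than k nodes when several nodes tie at the threshold value, which is worth keeping in mind when the goal is an exact top-k set.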
Features
The papers in the Cora dataset have 1433 features.
Each feature corresponds to one of the 1433 dictionary words and is 1 if the word appears in the paper and 0 if it does not.
Note that we are now looking at the node features. Edges may also have feature values (edge features), but they are not included in the Cora
dataset.
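As an illustration of this encoding, the sketch below builds a 0/1 word vector for a short text. The five-word dictionary and the sample sentence are hypothetical; Cora's actual dictionary and its preprocessing are not shown in this article.

```python
# Hypothetical 5-word dictionary; Cora's real dictionary has 1433 words.
dictionary = ["graph", "neural", "kernel", "bayesian", "genetic"]

def to_binary_vector(text, dictionary):
    # 1.0 if the dictionary word occurs in the text, else 0.0.
    words = set(text.lower().split())
    return [1.0 if word in words else 0.0 for word in dictionary]

paper = "A neural approach to graph classification"
print(to_binary_vector(paper, dictionary))  # [1.0, 1.0, 0.0, 0.0, 0.0]
```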
# Let's display some of the features, and you can see that they are composed of 0s and 1s.
print(len(data.x[0]))
data.x[0][:20]
1433
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 1.])
Classes
The papers in the Cora dataset are labeled with 7 different labels.
Number of classes: 7
Let’s display a portion of the labels. We can see that they are integers between 0 and 6. Each number corresponds to a subject as
follows. See here for details.
label_dict = {
    0: "Theory",
    1: "Reinforcement_Learning",
    2: "Genetic_Algorithms",
    3: "Neural_Networks",
    4: "Probabilistic_Methods",
    5: "Case_Based",
    6: "Rule_Learning"}
data.y[:10]
tensor([3, 4, 4, 0, 3, 2, 0, 3, 3, 2])
Classes are often imbalanced. Let's count the number of papers in each class.
counter = collections.Counter(data.y.numpy())
counter = dict(counter)
print(counter)
count = [x[1] for x in sorted(counter.items())]
plt.figure(figsize=(10, 6))
plt.bar(range(7), count)
plt.xlabel("class", size=20)
plt.show()
The largest class is class 3 with 818 papers, and the smallest is class 6 with 180. With an imbalance of this size, overall accuracy can hide poor
performance on the small classes, so we need to be careful when training machine learning models.
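Counting labels can also be sketched with `np.bincount`, which returns a slot for every class even when a class has no samples. The example below uses only the first ten labels printed earlier (`data.y[:10]`), so the counts describe that small sample, not the full dataset.

```python
import numpy as np

# The first ten labels printed earlier: data.y[:10].
labels = np.array([3, 4, 4, 0, 3, 2, 0, 3, 3, 2])

# minlength=7 keeps a slot for every class, even those with zero count.
counts = np.bincount(labels, minlength=7)
print(counts.tolist())  # [2, 0, 2, 4, 2, 0, 0]

# Share of each class within this small sample.
print((counts / counts.sum()).round(2).tolist())
# [0.2, 0.0, 0.2, 0.4, 0.2, 0.0, 0.0]
```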
Next, draw a network diagram to see if the classes are distributed coherently.
G = to_networkx(data, to_undirected=True)
node_color = []
nodelist = [[], [], [], [], [], [], []]
colorlist = ['#e41a1c', '#377eb8', '#4daf4a', '#984ea3', '#ff7f00', '#ffff33', '#a65628']
labels = data.y
for n, i in enumerate(labels):
    node_color.append(colorlist[i])
    nodelist[i].append(n)
pos = nx.spring_layout(G, seed = 42)
plt.figure(figsize = (14, 16))
labellist = list(label_dict.values())
for num, i in enumerate(zip(nodelist, labellist)):
    n, l = i[0], i[1]
    nx.draw_networkx_nodes(G, pos, nodelist=n, node_size = 5, node_color = colorlist[num], label=l)
nx.draw_networkx_edges(G, pos, width = 0.25)
plt.legend(bbox_to_anchor=(1, 1), loc='upper left')
plt.show()
It is a little difficult to see because it is plotted in two dimensions, but it looks as if the classes are somewhat grouped together. We will analyze
this point from a different angle in the next section.
Homophily
Nodes with the same characteristics are often connected. This property is called homophily. For the seven classes we looked at earlier, we will
count how many edges connect nodes of the same class and how many connect nodes of different classes.
labels = data.y.numpy()
connected_labels_set = list(map(lambda x: labels[x], data.edge_index.numpy()))
connected_labels_set = np.array(connected_labels_set)

def add_missing_keys(counter, classes):
    for x in classes:
        if x not in counter.keys():
            counter[x] = 0
    return counter

label_connection_counts = []
for i in range(7):
    print(f"label: {i}")
    connected_labels = connected_labels_set[:, np.where(connected_labels_set[0] == i)[0]]
    print(connected_labels.shape[1], "edges")
    counter = collections.Counter(connected_labels[1])
    counter = dict(counter)
    print(counter)
    counter = add_missing_keys(counter, range(7))
    items = sorted(counter.items())
    items = [x[1] for x in items]
    label_connection_counts.append(items)
label_connection_counts = np.array(label_connection_counts)
label: 0
1527 edges
{0: 1068, 1: 32, 3: 161, 6: 80, 5: 75, 4: 88, 2: 23}
label: 1
1029 edges
{1: 818, 3: 67, 0: 32, 5: 28, 2: 62, 4: 20, 6: 2}
label: 2
1826 edges
{2: 1654, 3: 53, 1: 62, 5: 30, 0: 23, 4: 2, 6: 2}
label: 3
2838 edges
{3: 2350, 2: 53, 4: 137, 5: 54, 0: 161, 6: 16, 1: 67}
label: 4
1592 edges
{4: 1320, 3: 137, 0: 88, 1: 20, 6: 6, 5: 19, 2: 2}
label: 5
1086 edges
{2: 30, 5: 834, 0: 75, 3: 54, 1: 28, 4: 19, 6: 46}
label: 6
658 edges
{6: 506, 5: 46, 0: 80, 4: 6, 3: 16, 1: 2, 2: 2}
plt.figure(figsize=(9, 7))
plt.rcParams["font.size"] = 13
hm = sns.heatmap(label_connection_counts, annot=True, cmap='hot_r', cbar=True, square=True)
plt.xlabel("class",size=20)
plt.ylabel("class",size=20)
plt.tight_layout()
plt.show()
We can see that most edges connect nodes that belong to the same class.
By dividing the sum of the diagonal components of the matrix by the sum of all components, we calculate the percentage of edges connected
within the same class.
label_connection_counts.diagonal().sum() / label_connection_counts.sum()
0.8099658961727927
It seems that about 81% of the edges are connected within the same class.
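The same ratio can be computed without building the class-by-class matrix, by comparing the labels at the two ends of every edge directly. This is a minimal sketch on a toy graph with made-up labels, not a rerun on Cora.

```python
import numpy as np

# Toy graph (made-up values): 4 nodes, each undirected edge stored twice.
edge_index = np.array([[0, 1, 1, 2, 2, 3],
                       [1, 0, 2, 1, 3, 2]])
labels = np.array([0, 0, 1, 1])

# Look up the class at both ends of every edge and compare them.
same_class = labels[edge_index[0]] == labels[edge_index[1]]

# The mean of a boolean array is the fraction of True entries,
# i.e. the fraction of edges that stay within one class.
homophily = same_class.mean()
print(homophily)
```

Here four of the six stored entries connect nodes of the same class, so the ratio is 4/6.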
def scaling(array):
    return array / sum(array)

label_connection_counts_scaled = np.apply_along_axis(scaling, 1, label_connection_counts)
plt.figure(figsize=(9, 7))
plt.rcParams["font.size"] = 13
hm = sns.heatmap(
    label_connection_counts_scaled,
    annot=True,
    cmap='hot_r',
    fmt="1.2f",
    cbar=True,
    square=True)
plt.xlabel("class",size=20)
plt.ylabel("class",size=20)
plt.tight_layout()
plt.show()
Excellent! We can see that for every class, the largest share of edges goes to the same class. The tendency is strongest for class 2, where
about 91% of the edges are joined within the same class. It is weakest for class 0, where about 70% of the edges
are within the same class.
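These per-class shares can be reproduced from the edge counts printed in the homophily section. The sketch below copies those counts into a matrix by hand and normalizes each row with broadcasting, as an alternative to `np.apply_along_axis`; the variable names are illustrative.

```python
import numpy as np

# Class-to-class edge counts copied from the printed output above
# (row = class of the source node, column = class of the target node).
counts = np.array([
    [1068,  32,   23,  161,   88,  75,  80],
    [  32, 818,   62,   67,   20,  28,   2],
    [  23,  62, 1654,   53,    2,  30,   2],
    [ 161,  67,   53, 2350,  137,  54,  16],
    [  88,  20,    2,  137, 1320,  19,   6],
    [  75,  28,   30,   54,   19, 834,  46],
    [  80,   2,    2,   16,    6,  46, 506],
])

# Divide each row by its own sum; keepdims=True keeps shape (7, 1)
# so the division broadcasts across columns.
row_share = counts / counts.sum(axis=1, keepdims=True)

# The diagonal holds the within-class share for each class.
within = row_share.diagonal()
print(within.round(2).tolist())
print(int(within.argmax()), int(within.argmin()))  # 2 0
```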
The data is split into 140 training nodes, 500 validation nodes, and 1000 test nodes. However, these add up to only 1640, not 2708.
Let’s check which nodes are used and which are not.
# 0 = unused, 1 = training, 2 = validation, 3 = test
split_type_array = np.zeros(data.num_nodes)
split_type_array[data.train_mask.numpy()] = 1
split_type_array[data.val_mask.numpy()] = 2
split_type_array[data.test_mask.numpy()] = 3
plt.scatter(range(data.num_nodes), split_type_array)
plt.xlabel("index")
plt.show()
We plot the x-axis as index and the y-axis as 0 for unused data, 1 for training data, 2 for validation data, and 3 for test data. It is an odd split, but
it appears that the data is split as above.
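The number of unused nodes follows from the three boolean masks. The sketch below uses toy masks over ten nodes with made-up positions, then repeats the arithmetic with the sizes quoted in the text.

```python
import numpy as np

# Toy masks over 10 nodes (made-up positions): 2 train, 3 val, 3 test.
num_nodes = 10
train_mask = np.zeros(num_nodes, dtype=bool)
val_mask = np.zeros(num_nodes, dtype=bool)
test_mask = np.zeros(num_nodes, dtype=bool)
train_mask[:2] = True
val_mask[2:5] = True
test_mask[7:] = True

# A node is unused if none of the three masks selects it.
unused = ~(train_mask | val_mask | test_mask)
print(np.flatnonzero(unused).tolist())  # [5, 6]

# With the sizes quoted in the text:
print(2708 - (140 + 500 + 1000))  # 1068
```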
titles = ["Training", "Validation", "Test"]
fig, axes = plt.subplots(ncols=3, figsize=(21, 6))
for i in range(3):
    counter = collections.Counter(data.y.numpy()[np.where(split_type_array == i + 1)[0]])
    counter = dict(counter)
    print(titles[i], counter)
    # Recompute the per-class counts for this split; without this line the
    # bars would reuse `count` from the full-dataset cell.
    count = [counter.get(c, 0) for c in range(7)]
    axes[i].bar(range(7), count)
    axes[i].set_xlabel("class", size=20)
    axes[i].set_title(titles[i])
plt.show()