r에서 Spinner로 만드는 Graph Net

R아두면 쓸데있는 패키지 이야기 05 Spinner package
R package
GNN
Author

chichead

Published

April 3, 2023

spinner package

오늘 소개할 R package는 Spinner package입니다. 로고에도 그려진 것처럼 Spinner는 실을 만드는 방적기, 방적공을 의미합니다. Spinner package는 토치(Torch)를 기반으로 Graph Net을 구현해주는 패키지입니다. 자세한 내용은 Spinner package를 만든 Giancarlo VercellinoRpub을 참조하세요.


Graph Layer

Graph Net은 그래프(혹은 구조화된 데이터)를 처리하기 위해 설계된 신경망 아키텍처입니다. Distill의 <A Gentle Introduction to Graph Neural Networks> 논문을 정리해보면서 이미 Graph Net을 살펴본 바 있습니다. 이웃한 노드나 엣지가 서로 정보를 교환해서 각각의 노드의 상태를 업데이트하는 Massage Passing을 이용한 Layer를 다루었죠. 그 중에서 노드에서 노드로, 엣지에서 엣지로, 노드에서 엣지로, 엣지에서 노드로, 혹은 이 4가지 방법을 모두 결합해서 마치 천을 직조하듯 구성한 Weave Layer도 살펴봤습니다.

기본적인 Graph Net의 연산과 마찬가지로 Spinner package는 그래프의 노드와 엣지 간에 정보를 전파하는 메시지 전달 연산(Message-Passing Operations), 그리고 수신된 메시지를 기반으로 새로운 노드, 엣지의 Feature를 계산하는 업데이트 함수로 구성됩니다.



Data Preparation

Spinner package는 그래프 샘플링(Graph Sampling)과 특징 추출(Feature Extraction)이라는 두 가지 작업에서부터 시작됩니다. 거대한 그래프의 경우에는 샘플링 값을 설정하고 Graph Density Threshold를 조정해서 하위 그래프를 샘플링할 수 있습니다. 그런 다음 Spinner package는 특징(Feature)을 추출합니다. 그래프에 Feature 값이 없는 경우에는 알고리즘은 Null value, 인접 임베딩 또는 라플라시안 임베딩을 이용해서 New Feature를 계산합니다. Null value는 관련 정보가 없다는 의미이고, 인접 임베딩은 그래프의 인접 행렬을 통해 노드 간의 관계를 포착합니다. 라플라시안 임베딩은 라플라시안 행렬을 분해하여 로컬 및 글로벌 속성을 포착합니다. 결측값이 있는 특징의 경우 empirical distribution을 사용하여 무작위 대입을 수행합니다.


Graph Net Layers

Spinner가 생성한 레이어는 Message Passing과 graph-independent forward network로 구성됩니다. Message Passing에서 그래프의 각 노드는 인접 노드로부터 메시지를 받고, 받은 메시지는 노드의 Feature 표현을 업데이트하는 데 사용됩니다. update_order 조건을 사용하면 다양한 옵션을 사용할 수 있죠. 업데이트의 조합은 선형 변환을 기반으로 합니다. graph-independent forward network는 업데이트된 Feature 표현을 가져와 DNN 변환을 적용합니다. 이 과정은 선택한 수의 레이어에 대해 반복되므로 알고리즘이 기능을 세분화하고 그래프의 더 복잡한 표현을 구축할 수 있습니다.


Output Transformation

Graph Net Layers가 완료되면 optional skip shortcut을 적용할 수 있습니다. skip shortcut을 사용하면 알고리즘이 특정 레이어를 건너뛰고 입력을 출력 레이어에 직접 연결하여 알고리즘의 효율성을 개선할 수 있죠. 출력 단계에선 Regression Tasks에 대한 선형 변환(Continuous range에 매핑하는 선형 변환 / Label Feature의 경우엔 확률 분포에 매핑하는 softmax/sigmoid activation)이 이뤄집니다. 마지막 단계에선 주어진 그래프 특징에 대한 예측 값 또는 확률을 나타내는 그래프 넷 알고리즘의 최종 출력을 생성합니다.


Examples with a graph

이제부터 본격적으로 그래프를 가지고 진행해보겠습니다. r에서 그래프를 그리기 위해 igraph package와 ggplot2 환경에서 그래프를 그리게 해주는 ggnetwork package를 이용하겠습니다. 우선 100개의 노드를 가지고 있는 작은 더미 그래프를 만들어보죠. 그래프에는 노드와 엣지에 각각 2개의 Feature를 넣어두겠습니다. 먼저 하나는 정규화된 연결 중심성(Degree Centrality, 한 노드에 연결된 엣지의 개수)이고, 또 하나는 cut of betweenness statstics입니다. 컨텍스트/글로벌 그래프에 대한 특징 값은 따로 없습니다.

library(igraph)
library(ggplot2)
library(ggnetwork)

set.seed(1004)
dummy_graph <- random.graph.game(100, 0.05) # 100개의 노드, 노드간 엣지 연결 확률 0.05인 그래프 생성

# Feature 넣어주기
V(dummy_graph)$node_feat1 <- degree(dummy_graph, normalized = T) + runif(50)
V(dummy_graph)$node_feat2 <- as.character(cut(betweenness(dummy_graph, normalized = T), 3))

E(dummy_graph)$edge_feat1 <- degree(line.graph(dummy_graph), normalized = T) + runif(ecount(dummy_graph))
E(dummy_graph)$edge_feat2 <- as.character(cut(betweenness(line.graph(dummy_graph), normalized = T), 2))

ggplot(ggnetwork(dummy_graph), aes(x = x, y = y, xend = xend, yend = yend)) +
  geom_edges(aes(color = edge_feat2)) + 
  geom_nodes(aes(size = node_feat1, color = node_feat2)) + 
  theme_void() + 
  guides(size = 'none', color = 'none') + 
  scale_color_manual(values = viridis::viridis(15, direction = -1, option = "B")[c(3, 6, 9, 12, 15)])


spinner 함수에 필요한 최소한의 파라미터는 그래프, 예측 대상(노드나 엣지), 노드, 에지 및 컨텍스트 Feature에 대한 레이블입니다. (위에서도 이야기 했지만 Feature가 없는 경우엔 임베딩 방법을 사용하여 새로운 특징을 계산합니다. Feature가 없는 경우 기본 임베딩 크기는 5이고 relative arguments를 사용하여 노드, 엣지 및 컨텍스트에 대해 수정할 수 있습니다) 이번 연습에서는 모든 노드와 엣지의 Feature를 사용하고(기본 옵션으로 컨텍스트를 5개의 0 벡터로 초기화) 엣지에 예측 타깃을 설정하여 2-folds, 3-repetitions의 cross-validation을 해보겠습니다.

library(spinner)

example1 <- spinner(dummy_graph, target = "edge", 
                    node_labels = c("node_feat1", "node_feat2"), 
                    edge_labels = c("edge_feat1", "edge_feat2"), 
                    holdout = 0.6, 
                    reps = 3, 
                    folds = 2, 
                    n_layers = 1)
epoch:  10    Train loss:  0.6820657    Val loss:  0.6711463 
epoch:  20    Train loss:  0.6089707    Val loss:  0.6455721 
epoch:  30    Train loss:  0.6605015    Val loss:  0.6717189 
epoch:  40    Train loss:  0.5759389    Val loss:  0.6358511 
early stop at epoch:  41    Train loss:  0.6066641    Val loss:  0.6683025 
epoch:  10    Train loss:  0.7671983    Val loss:  0.661452 
epoch:  20    Train loss:  0.7404004    Val loss:  0.6647233 
epoch:  30    Train loss:  0.7611908    Val loss:  0.7213398 
early stop at epoch:  39    Train loss:  0.7686964    Val loss:  0.6944773 
epoch:  10    Train loss:  0.6721007    Val loss:  0.6914219 
epoch:  20    Train loss:  0.7303004    Val loss:  0.7528074 
epoch:  30    Train loss:  0.6592697    Val loss:  0.7057204 
epoch:  40    Train loss:  0.698299    Val loss:  0.6622039 
early stop at epoch:  44    Train loss:  0.6732623    Val loss:  0.7268654 
epoch:  10    Train loss:  0.6426511    Val loss:  0.7335425 
epoch:  20    Train loss:  0.7605699    Val loss:  0.7579686 
epoch:  30    Train loss:  0.7845948    Val loss:  0.7324713 
epoch:  40    Train loss:  0.7046131    Val loss:  0.7116204 
early stop at epoch:  48    Train loss:  0.5067006    Val loss:  0.7423382 
epoch:  10    Train loss:  0.5613942    Val loss:  0.6729948 
epoch:  20    Train loss:  0.6164793    Val loss:  0.6531799 
epoch:  30    Train loss:  0.6119072    Val loss:  0.683926 
early stop at epoch:  32    Train loss:  0.6206366    Val loss:  0.6946394 
epoch:  10    Train loss:  0.896782    Val loss:  0.7952279 
epoch:  20    Train loss:  0.9012119    Val loss:  0.7725604 
epoch:  30    Train loss:  0.493989    Val loss:  0.7581556 
epoch:  40    Train loss:  0.63601    Val loss:  0.7295729 
epoch:  50    Train loss:  0.9077    Val loss:  0.7525882 
epoch:  60    Train loss:  0.4619842    Val loss:  0.7336836 
epoch:  70    Train loss:  0.9050267    Val loss:  0.7550592 
epoch:  80    Train loss:  0.8885249    Val loss:  0.7225138 
epoch:  90    Train loss:  0.8951436    Val loss:  0.7863429 
epoch:  100    Train loss:  0.892521    Val loss:  0.7861573 
epoch:  10    Train loss:  0.7193506    Val loss:  0.6954241 
epoch:  20    Train loss:  0.6814125    Val loss:  0.6942275 
epoch:  30    Train loss:  0.6652368    Val loss:  0.7037159 
early stop at epoch:  34    Train loss:  0.6795753    Val loss:  0.7075669 
time: 14.264 sec elapsed


함수를 돌리면 나오는 결과값에는 출력에는 그래프(원본 or 샘플링), 모델 설명 및 요약, 새 그래프 데이터에 대한 예측, 교차 검증 및 요약 오류, 손실 함수에 대한 플롯(최종 학습 및 테스트용) 및 시간 로그가 포함되어 있습니다.

example1$model_description
[1] "model with 1 GraphNet layers, 1 classification tasks and 1 regression tasks (1029 parameters)"
example1$model_summary
$GraphNetLayer1
An `nn_module` containing 1,027 parameters.

── Modules ─────────────────────────────────────────────────────────────────────
• context_to_edge: <nn_pooling_from_context_to_edges_layer> #18 parameters
• context_to_node: <nn_pooling_from_context_to_nodes_layer> #24 parameters
• edge_to_context: <nn_pooling_from_edges_to_context_layer> #20 parameters
• edge_to_node: <nn_pooling_from_edges_to_nodes_layer> #16 parameters
• node_to_context: <nn_pooling_from_nodes_to_context_layer> #25 parameters
• node_to_edge: <nn_pooling_from_nodes_to_edges_layer> #15 parameters
• node_fusion: <nn_linear> #3 parameters
• edge_fusion: <nn_linear> #3 parameters
• context_fusion: <nn_linear> #3 parameters
• independent_layer: <nn_graph_independent_forward_layer> #900 parameters

$classif1
An `nn_module` containing 0 parameters.

$regr1
An `nn_module` containing 2 parameters.

── Parameters ──────────────────────────────────────────────────────────────────
• weight: Float [1:1, 1:1]
• bias: Float [1:1]
example1$cv_errors
  reps folds     train validation
1    1     1 0.6066641  0.6683025
2    1     2 0.7686964  0.6944773
3    2     1 0.6732623  0.7268654
4    2     2 0.5067006  0.7423382
5    3     1 0.6206366  0.6946394
6    3     2 0.8925210  0.7861573
example1$summary_errors
                train validation.validation                  test 
            0.6795753             0.7187967             0.7075669 
example1$history + theme_minimal()