作者:Steeve Huang
編譯:ronghuaiyang
導讀
給大家介紹目前非常熱門的圖神經網絡,包括基礎和兩個常用算法,DeepWalk和GraphSage。
近年來,圖神經網絡(GNN)在社交網絡、知識圖譜、推薦系統甚至生命科學等各個領域得到了越來越廣泛的應用。GNN具有對圖中節點間依賴關係建模的強大功能,使得圖分析相關研究領域取得了突破。本文會介紹圖神經網絡的基本原理,以及兩種更高級的算法,DeepWalk和GraphSage。
圖
在討論GNN之前,讓我們先了解什麼是Graph。在計算機科學中,圖是由頂點和邊兩部分組成的數據結構。一個圖G可以由它所包含的頂點V和邊E的集合很好地描述。
根據頂點之間是否存在方向依賴關係,邊可以是有向的,也可以是無向的。
有向圖
這些頂點通常稱為節點。在本文中,這兩個術語是可以互換的。
圖神經網絡
圖神經網絡是一種直接作用於圖結構上的神經網絡。GNN的一個典型應用是節點分類。本質上,圖中的每個節點都與一個標籤相關聯,我們希望在沒有ground-truth的情況下預測節點的標籤。本節將說明本文中描述的算法。第一個提出的GNN常常被稱為原始GNN。
在節點分類問題設置中,每個節點v的特徵是x_v,並與一個ground-truth標籤t_v關聯。給定一個部分標記的圖G,目標是利用這些標記的節點來預測未標記的標籤。它學習用一個d維向量(狀態)h_v表示每個節點,其中包含其鄰域的信息。具體地說,
其中x_co[v]表示與v相連的邊的特徵,h_ne[v]表示與v相鄰節點的嵌入,x_ne[v]表示與v相鄰節點的特徵。函數f是將這些輸入投射到d維空間的轉換函數。因為我們正在為h_v尋找一個惟一的解,所以我們可以應用Banach定點定理並將上面的等式重寫為疊代更新過程。這種操作通常稱為消息傳遞或鄰居聚合。
H和X分別表示所有H和X的級聯。
H和X分別表示所有H和X的級聯。
通過將狀態h_v和特徵x_v傳遞給輸出函數g來計算GNN的輸出。
這裡的f和g都可以解釋為前饋全連接神經網絡。L1損失可以直接表述為:
可以通過梯度下降來優化。
然而,有文章指出,GNN的這一原始方法存在三個主要限制:
- 如果放鬆「不動點」的假設,就有可能利用多層感知器來學習更穩定的表示,並消除疊代更新過程。這是因為,在原方案中,不同的疊代使用相同的轉換函數f的參數,而MLP不同層中不同的參數允許分層特徵提取。
- 它不能處理邊緣信息(例如,知識圖中不同的邊緣可能表示節點之間不同的關係)
- 不動點會阻礙節點分布的多樣化,因此可能不適合學習表示節點。
已經提出了幾個GNN的變體來解決上述問題。然而,它們沒有被涵蓋,因為它們不是本文的重點。
DeepWalk
DeepWalk是第一個提出以無監督方式學習節點嵌入的算法。就訓練過程而言,它非常類似於單詞嵌入。其動機是圖中節點和語料庫中單詞的分布遵循冪律,如下圖所示:
算法包括兩個步驟:
- 在圖中的節點上執行隨機漫步以生成節點序列
- 根據步驟1生成的節點序列,運行skip-gram,學習每個節點的嵌入
在隨機遊走的每個時間步長上,下一個節點均勻地從前一個節點的鄰居中採樣。然後將每個序列截斷為長度2|w| + 1的子序列,其中w表示skip-gram的窗口大小。
本文採用層次softmax 算法,解決了節點數量大、計算量大的softmax問題。要計算每個單獨輸出元素的softmax值,我們必須計算所有元素k的所有e^xk。
因此,原始softmax的計算時間為O(|V|),其中V表示圖中頂點的集合。
層次softmax利用二叉樹來處理該問題。在這個二叉樹中,所有的葉子(上圖中的v1 v2…v8)都是圖中的頂點。在每個內部節點中,都有一個二進位分類器來決定選擇哪條路徑。要計算給定頂點v_k的機率,只需計算從根節點到左節點的路徑上的每個子路徑的機率v_k。由於每個節點的子節點的機率之和為1,所以所有頂點的機率之和等於1的性質在層次softmax中仍然成立。現在一個元素的計算時間減少到O(log|V|),因為二叉樹的最長路徑以O(log(n))為界,其中n是葉子的數量。
經過DeepWalk GNN的訓練,模型學習到每個節點的良好表示,如下圖所示。不同的顏色表示輸入圖中不同的標籤。我們可以看到,在輸出圖中(2維嵌入),具有相同標籤的節點聚集在一起,而具有不同標籤的大多數節點被正確地分離。
然而,DeepWalk的主要問題是它缺乏泛化的能力。每當一個新節點出現時,它都必須對模型進行重新訓練,才可以表示這個節點。因此,這種GNN不適用於圖中節點不斷變化的動態圖。
GraphSage
GraphSage提供了一個解決上述問題的解決方案,以歸納的方式學習每個節點的嵌入。具體地說,每個節點由其鄰域的聚合表示。因此,即使在圖中出現了訓練過程中沒有出現過的新的節點,也可以用相鄰的節點來表示。下面是GraphSage的算法。
外循環表示更新疊代次數,h^k_v表示更新疊代時節點v的隱向量k。在每次更新疊代中,根據一個聚集函數、前一次疊代中v和v鄰域的隱向量以及權矩陣W^k對h^k_v進行更新。本文提出了三個聚合函數:
1. Mean aggregator:
mean aggregator取一個節點及其所有鄰域的隱向量的平均值。
與原始方程相比,它刪除了上面偽代碼第5行中的連接操作。這種操作可以看作是一種「跳躍連接」,本文稍後的部分證明了這種連接在很大程度上提高了模型的性能。
2. LSTM aggregator:
由於圖中的節點沒有任何順序,它們通過遍歷這些節點來隨機分配順序。
3. Pooling aggregator:
這個操作符在相鄰的集合上執行一個元素池化函數。下面是最大池化的例子:
可以用均值池化或任何其他對稱池化函數替換。池化聚合器性能最好,而均值池化聚合器和最大池化聚合器性能相近。本文使用max-pooling作為默認的聚合函數。
損失函數定義如下:
其中u和v共出現在固定長度的隨機遊動中,v_n是與u不共出現的負樣本。這種損失函數鼓勵距離較近的節點進行類似的嵌入,而距離較遠的節點則在投影空間中進行分離。通過這種方法,節點將獲得越來越多的關於其鄰域的信息。
GraphSage通過聚合其附近的節點,為不可見的節點生成可表示的嵌入。它允許將節點嵌入應用於涉及動態圖的領域,其中圖的結構是不斷變化的。例如,Pinterest採用GraphSage的擴展版本PinSage作為內容發現系統的核心。
總結
你學習了圖形神經網絡、DeepWalk和GraphSage的基礎知識。GNN在複雜圖形結構建模方面的能力確實令人吃驚。鑒於其有效性,我相信在不久的將來,GNN將在人工智慧的發展中發揮重要的作用。
英文原文:https://towardsdatascience.com/a-gentle-introduction-to-graph-neural-network-basics-deepwalk-and-graphsage-db5d540d50b3
請長按或掃描二維碼關注本公眾號