實(shí)現(xiàn)大規(guī)模圖計(jì)算的算法思路
分享嘉賓:徐瀟然 Hulu 研究員
編輯整理:莫高鼎
出品平臺:DataFunTalk
導(dǎo)讀:2017年我以深度學(xué)習(xí)研究員的身份加入Hulu,研究領(lǐng)域包括了圖神經(jīng)網(wǎng)絡(luò)及NLP中的知識圖譜推理,其中我們在大規(guī)模圖神經(jīng)網(wǎng)絡(luò)計(jì)算方向的工作發(fā)表在ICLR2020主會上,題目是——Dynamically Pruned Message Passing Networks for Large-Scale Knowledge Graph Reasoning。本次分享的話題會沿著這個(gè)方向,重點(diǎn)和大家探討一下并列出一些可以降低大規(guī)模圖計(jì)算復(fù)雜度的思路。
1. 圖神經(jīng)網(wǎng)絡(luò)使用的圖
圖神經(jīng)網(wǎng)絡(luò)這幾年特別火爆,無論學(xué)術(shù)界還是業(yè)界,大家都在考慮用圖神經(jīng)網(wǎng)絡(luò)。正因?yàn)閳D神經(jīng)網(wǎng)絡(luò)的應(yīng)用面很廣,所用的圖各種各樣都有,簡單分類如下:
① 根據(jù)圖與樣本的關(guān)系
-
全局圖:所有樣本共用一個(gè)大圖
比如有一個(gè)大而全的知識圖譜,所做任務(wù)的每一個(gè)樣本都共用這個(gè)知識圖譜,使用來自這個(gè)知識圖譜的一部分信息。
-
實(shí)例圖:以每個(gè)樣本為中心構(gòu)建的圖
每個(gè)輸入的樣本自帶一個(gè)圖,比如要考慮一張圖片中所有物體之間的關(guān)系,這可以構(gòu)成一個(gè)物體間關(guān)系圖。換一張圖片后,就是另一張關(guān)系圖。
② 根據(jù)邊的連接密度
-
完全圖
-
稀疏圖
2. 圖神經(jīng)網(wǎng)絡(luò)與傳統(tǒng)神經(jīng)網(wǎng)絡(luò)的聯(lián)系
神經(jīng)網(wǎng)絡(luò)原本就是圖,我們大多只是提到“權(quán)重”和“層”,再細(xì)粒度一點(diǎn),會講到“單元”(即units)。但是,有圖就有節(jié)點(diǎn)和邊的概念,就看你怎么定義這個(gè)節(jié)點(diǎn)。在BERT網(wǎng)絡(luò)結(jié)構(gòu)中,輸入是一個(gè)文本序列, 預(yù)處理成一串代表word或sub-word的tokens,我們可以把這些tokens看成是圖中的nodes,這樣BERT變成了一個(gè)完全圖上的圖神經(jīng)網(wǎng)絡(luò),而且BERT網(wǎng)絡(luò)結(jié)構(gòu)的每層可以對應(yīng)到圖神經(jīng)網(wǎng)絡(luò)的一次message passing迭代。
3. 圖神經(jīng)網(wǎng)絡(luò)與傳統(tǒng)神經(jīng)網(wǎng)絡(luò)的區(qū)別
傳統(tǒng)神經(jīng)網(wǎng)絡(luò)有多個(gè)層的概念,每一層用的都是不同的參數(shù);圖神經(jīng)網(wǎng)絡(luò)只有一個(gè)圖,圖中計(jì)算通過多步迭代完成節(jié)點(diǎn)間的消息傳遞和節(jié)點(diǎn)狀態(tài)更新。這種迭代式的計(jì)算,有點(diǎn)類似神經(jīng)網(wǎng)絡(luò)的多個(gè)層,但是迭代中使用的是同一套權(quán)重參數(shù),這點(diǎn)又像單層的RNN。當(dāng)然,如果不嫌復(fù)雜,你可以堆疊多個(gè)圖,下層圖向上層圖提供輸入,讓圖神經(jīng)網(wǎng)絡(luò)有“層”的概念。
另外,圖神經(jīng)網(wǎng)絡(luò)中的nodes與傳統(tǒng)神經(jīng)網(wǎng)絡(luò)中的units不同。圖神經(jīng)網(wǎng)絡(luò)中的nodes是有狀態(tài)的(stateful),不像傳統(tǒng)神經(jīng)網(wǎng)絡(luò)中的units,當(dāng)一層計(jì)算完輸出給下一層后,這層units的生命就結(jié)束了。Nodes的狀態(tài)表示為一個(gè)向量,在下次迭代時(shí)會更新。此外,你也可以考慮為edges和global定義它們的狀態(tài)。
4. 圖神經(jīng)網(wǎng)絡(luò)的計(jì)算框架
① 初始步
-
初始化每個(gè)節(jié)點(diǎn)的狀態(tài)向量(可以包括各條邊和全局的狀態(tài))
② 消息傳遞(message-passing)迭代步:
-
計(jì)算節(jié)點(diǎn)到節(jié)點(diǎn)的消息向量
-
計(jì)算節(jié)點(diǎn)到節(jié)點(diǎn)的(多頭)注意力分布
-
對節(jié)點(diǎn)收到的消息進(jìn)行匯總計(jì)算
-
更新每個(gè)節(jié)點(diǎn)的狀態(tài)向量(可以包括各條邊和全局的狀態(tài))
5. 圖神經(jīng)網(wǎng)絡(luò)的計(jì)算復(fù)雜度
計(jì)算復(fù)雜度主要分為空間復(fù)雜度和時(shí)間復(fù)雜度。我們使用PyTorch或者TensorFlow進(jìn)行神經(jīng)網(wǎng)絡(luò)訓(xùn)練或預(yù)測時(shí),會遇到各種具體的復(fù)雜度,比如會有模型參數(shù)規(guī)模的復(fù)雜度,還有計(jì)算中產(chǎn)生中間tensors大小的復(fù)雜度,以及一次前向計(jì)算中需保存tensors個(gè)數(shù)的復(fù)雜度。我們訓(xùn)練神經(jīng)網(wǎng)絡(luò)時(shí),它做前向計(jì)算的過程中,由于梯度反向傳播的需要,前面層計(jì)算出的中間tensors要保留。但在預(yù)測階段,不需要梯度反向傳播,可以不保留中間產(chǎn)生的tensors,這會大大降低空間上的開銷。物理層面,我們現(xiàn)在用的GPU,一張卡的顯存頂?shù)教煲簿?4G,這個(gè)尺寸還是有限的,但是實(shí)際中遇到的很多圖都非常之大。另外,就是時(shí)間復(fù)雜度了。下面,我們用T表示一次圖計(jì)算中的迭代個(gè)數(shù),B表示輸入樣本的批大小(batch size),|V|表示節(jié)點(diǎn)個(gè)數(shù),|E|表示邊個(gè)數(shù),D,D1,D2表示表征向量的維數(shù)。
空間復(fù)雜度
-
模型參數(shù)規(guī)模
-
計(jì)算中間產(chǎn)生tensors規(guī)模(此時(shí)有B>=1, T=1)
-
計(jì)算中間保留tensors規(guī)模(此時(shí)有B>=1, T>=1)
時(shí)間復(fù)雜度
-
計(jì)算所需浮點(diǎn)數(shù)規(guī)模(此時(shí)考慮D1, D2)
總結(jié)復(fù)雜度的計(jì)算公式,不外乎如下的形式:
思路一:避開|E|
通常情況下,圖中邊的個(gè)數(shù)遠(yuǎn)大于節(jié)點(diǎn)的數(shù)量。極端情況下,當(dāng)邊的密度很高直至完全圖時(shí),圖的復(fù)雜度可以達(dá)到|V|(|V|-1)/2。如果考慮兩個(gè)節(jié)點(diǎn)間雙向的邊,以及節(jié)點(diǎn)到自身的特殊邊,那么這個(gè)復(fù)雜度就是|V|2。為了降低計(jì)算的復(fù)雜度,一個(gè)思路就是盡量避開圍繞邊的計(jì)算。具體來說,為了讓計(jì)算復(fù)雜度從|E|級別降低為|V|級別,在計(jì)算消息向量(message vectors)時(shí),我們僅計(jì)算 destination-independent messages。也就是說,從節(jié)點(diǎn)u發(fā)出的所有消息使用同一個(gè)向量,這樣復(fù)雜度從邊數(shù)級別降為了節(jié)點(diǎn)數(shù)級別。值得注意的是,這里會存在一個(gè)問題,消息向量里不區(qū)分不同的destination節(jié)點(diǎn)。那么,能否把不同的destination節(jié)點(diǎn)考慮進(jìn)來呢?當(dāng)然可以,不過需要引入multi-head attention機(jī)制。下面針對這種情況來介紹一下優(yōu)化方案。
適合情形
當(dāng)|E|>>|V|時(shí),即邊密度高的圖,尤其是完全圖
優(yōu)化方案
思路二:減少D
順著思路一,我們在計(jì)算attention時(shí),每個(gè)attention分?jǐn)?shù)都是一個(gè)標(biāo)量。我們可以減小計(jì)算attention所用的向量維數(shù),因?yàn)檩敵鍪且粋€(gè)標(biāo)量,信息被壓縮到一維空間,所以計(jì)算時(shí)沒必要使用大向量來提高capacity。如果需要multi-head的話,可以把每個(gè)計(jì)算channel的向量維數(shù)變小,讓它們加起來還等于原來的總維數(shù)。這個(gè)思路很像BERT,BERT雖然不是GNN,但是這種機(jī)制可以運(yùn)用到GNN中。還有一篇論文,提出了Graph Attention Networks,也用到了類似的思路。
適合情形
引入attention mechanism的multi-head channels設(shè)計(jì)
優(yōu)化方案
每個(gè)head channel 的消息計(jì)算使用較小的hidden dimensions, 通過增加head的數(shù)量來保證模型的capacity,而每個(gè)head的attention 分?jǐn)?shù)在一個(gè)節(jié)點(diǎn)上僅僅是一個(gè)標(biāo)量。
思路三:部分迭代更新(選擇性減少T)
前面的思路是減少邊數(shù)量以及計(jì)算維度數(shù),我們還可以減少迭代次數(shù)T,這樣中間需保留tensors的規(guī)模就會變小,適合非常大的網(wǎng)絡(luò),尤其當(dāng)網(wǎng)絡(luò)節(jié)點(diǎn)刻畫的時(shí)間跨度很大,或者異構(gòu)網(wǎng)絡(luò)的不同節(jié)點(diǎn)需要不同頻次或不同階段下的更新。有些節(jié)點(diǎn)不需要迭代更新那么多次,迭代兩、三次就夠了,有些節(jié)點(diǎn)要更新好多次才行。下圖的右側(cè)部分,每步迭代節(jié)點(diǎn)都更新;左側(cè)部分,節(jié)點(diǎn)只更新一次,即使這樣,它的計(jì)算依賴鏈條還是有四層。至于更新策略,可以人為設(shè)定,比如說,采取隨機(jī)抽樣方式,或者通過學(xué)習(xí)得到哪些節(jié)點(diǎn)需更新的更新策略。更新策略的數(shù)學(xué)實(shí)現(xiàn),可以采取hard gate的方式(注意不是soft),也可以采取sparse attention即選擇top-K節(jié)點(diǎn)的方式。有paper基于損失函數(shù)設(shè)計(jì)criteria去選擇更新的節(jié)點(diǎn),如果某個(gè)節(jié)點(diǎn)的當(dāng)前輸出對最終損失函數(shù)的貢獻(xiàn)已經(jīng)很好了,就不再更新。需要注意的是,在hard gate和sparse attention的代碼實(shí)現(xiàn)中,不能簡單地把要略過的節(jié)點(diǎn)的權(quán)重置零,雖然數(shù)學(xué)上等價(jià),但是CPU或GPU還是要計(jì)算的,所以代碼中需要實(shí)現(xiàn)稀疏性計(jì)算,來減少每次更新所載入的tensor規(guī)模。更新的粒度可以是逐點(diǎn)的,也可以是逐塊的。
適合情形
具有大時(shí)間跨度或異構(gòu)的網(wǎng)絡(luò),其節(jié)點(diǎn)需不同頻次或不同階段下的更新
優(yōu)化方案
更新策略一:預(yù)先設(shè)定每步更新節(jié)點(diǎn)
更新策略二:隨機(jī)抽樣每步更新節(jié)點(diǎn)
更新策略三:每步每節(jié)點(diǎn)通過hard gate的開關(guān)決定是否更新
更新策略四:每步通過sparse attention機(jī)制選擇top-K節(jié)點(diǎn)進(jìn)行更新
更新策略五:根據(jù)設(shè)定的criteria選擇更新節(jié)點(diǎn)(如:非shortcut支路上梯度趨零)
思路四:Baking(“烘焙”,即使用臨時(shí)memory存放某些計(jì)算結(jié)果)
Baking這個(gè)名字,是我引用計(jì)算機(jī)3D游戲設(shè)計(jì)中的一個(gè)名詞,來對深度學(xué)習(xí)中一種常見的技巧起的名字。當(dāng)某些數(shù)據(jù)的計(jì)算復(fù)雜度很高時(shí),我們可以提前算好它,后面需要時(shí)就直接拿來。這些數(shù)據(jù)通常需要一個(gè)臨時(shí)的記憶模塊來存儲。大時(shí)間跨度的早期計(jì)算節(jié)點(diǎn),或者異構(gòu)網(wǎng)絡(luò)的一些非重要節(jié)點(diǎn),我們假定它們對當(dāng)前計(jì)算的作用只是參考性的、非決定性的,并設(shè)計(jì)它們只參與前向計(jì)算,不參與梯度的反向傳播,此時(shí)我們可以使用記憶模塊保存這些算好的數(shù)據(jù)。記憶模塊的設(shè)計(jì),最簡單的就是一組向量,每個(gè)向量為一個(gè)記憶槽(slot),訪問過程可以是嚴(yán)格的索引匹配,或者采用soft attention機(jī)制。
適合情形
大時(shí)間跨度的早期計(jì)算節(jié)點(diǎn)或者異構(gòu)網(wǎng)絡(luò)的一些非重要節(jié)點(diǎn)(只參與前向計(jì)算,不參與梯度的反向傳播)。
優(yōu)化方案
維護(hù)一個(gè)記憶緩存,保存歷史計(jì)算的某些節(jié)點(diǎn)狀態(tài)向量,對緩存的訪問可以是嚴(yán)格索引匹配,也可以使用soft attention機(jī)制。
思路五:Distillation(蒸餾技術(shù))
蒸餾技術(shù)的應(yīng)用非常普遍。蒸餾的思想就是用層數(shù)更小的網(wǎng)絡(luò)來代替較重的大型網(wǎng)絡(luò)。實(shí)際上,所有神經(jīng)網(wǎng)絡(luò)的蒸餾思路都類似,只不過在圖神經(jīng)網(wǎng)絡(luò)里,要考慮如何把一個(gè)重型網(wǎng)絡(luò)壓縮成小網(wǎng)絡(luò)的具體細(xì)節(jié),包括要增加什么樣的loss來訓(xùn)練。這里,要明白蒸餾的目的不是僅僅為了學(xué)習(xí)到一個(gè)小網(wǎng)絡(luò),而是要讓學(xué)習(xí)出的小網(wǎng)絡(luò)可以很好地反映所給的重型網(wǎng)絡(luò)。小網(wǎng)絡(luò)相當(dāng)于重型網(wǎng)絡(luò)在低維空間的一個(gè)投影。實(shí)際上,用一個(gè)小的參數(shù)空間去錨定重型網(wǎng)絡(luò)的中間層features,基于hidden層或者attention層做對齊,盡量讓小網(wǎng)絡(luò)在某些中間層上產(chǎn)生與重型網(wǎng)絡(luò)相對接近的features。
適合情形
對已訓(xùn)練好的重型網(wǎng)絡(luò)進(jìn)行維度壓縮、層壓縮或稀疏性壓縮,讓中間層的feature space表達(dá)更緊湊。
優(yōu)化方案
Distillation Loss的設(shè)計(jì)方案:
-
Hidden-based loss
-
Attention-based loss
思路六:Partition (or clustering)
如果圖非常非常大,那該怎么辦?只能采取圖分割(graph partition)的方法了。我們可以借用傳統(tǒng)的圖分割或節(jié)點(diǎn)聚類算法,但是這些算法大多很耗時(shí),故不能采取過于復(fù)雜的圖分割或節(jié)點(diǎn)聚類算法。分割過程要注意執(zhí)行分割算法所用的節(jié)點(diǎn)數(shù)據(jù),最好不要直接在節(jié)點(diǎn)hidden features上做分割或聚類計(jì)算,這是因?yàn)橹挥衕idden features相似的nodes才會聚到一起,可能存在某些相關(guān)但hidden features不接近的節(jié)點(diǎn)需要放在一個(gè)組里。我們可以將hidden features做非線性轉(zhuǎn)換到某個(gè)分割語義下的空間,這個(gè)非線性轉(zhuǎn)換是帶參的,需要訓(xùn)練,即分割或聚類過程是學(xué)習(xí)得到的。每個(gè)分割后的組,組內(nèi)直接進(jìn)行節(jié)點(diǎn)到節(jié)點(diǎn)的消息傳遞,組間消息傳遞時(shí)先對一組節(jié)點(diǎn)做池化(pooling)計(jì)算,得到一個(gè)反映整個(gè)組的狀態(tài)向量,再通過這個(gè)向量與其他組的節(jié)點(diǎn)做消息傳遞。另外的關(guān)鍵一點(diǎn)是如何通過最終的損失函數(shù)來訓(xùn)練分割或聚類計(jì)算中的可訓(xùn)參數(shù)。我們可以把節(jié)點(diǎn)對組的成員關(guān)系(membership)引入到計(jì)算流程中,使得反向傳播時(shí)可以獲得相應(yīng)的梯度信息。當(dāng)然,如果不想這么復(fù)雜,你可以提前對圖做分割, 然后進(jìn)行消息傳遞。
適合情形
針對非常大的圖(尤其是完全圖)
優(yōu)化方案
對圖做快速分割處理,劃分節(jié)點(diǎn)成組,然后在組內(nèi)進(jìn)行節(jié)點(diǎn)到節(jié)點(diǎn)的消息傳遞,在組間進(jìn)行組到節(jié)點(diǎn)、或組到組的消息傳遞。
① Transformation step
-
Project hidden features onto the partition-oriented space
② Partitioning step
③ Group-pooling step
-
Compute group node states
④ Message-passing step
-
Compute messages from within-group neighbors
-
Compute messages from the current group node
-
Compute messages from other group nodes
思路七:稀疏圖計(jì)算
如何利用好稀疏圖把復(fù)雜度降下來?你不能把稀疏圖當(dāng)作dense矩陣來處理,并用Tensorflow或PyTorch做普通tensors間的計(jì)算,這是沒有效果的。你必須維護(hù)一個(gè)索引列表,而且這個(gè)索引列表支持快速的sort、unique、join等操作。舉個(gè)例子,你需要維護(hù)一份索引列表如下圖,第一列代表batch中每個(gè)sample的index,第二列代表source node的id。當(dāng)用節(jié)點(diǎn)狀態(tài)向量計(jì)算消息向量時(shí), 需要此索引列表與邊列表edgelist做join,把destination node的id引進(jìn)來,完成節(jié)點(diǎn)狀態(tài)向量到邊向量的轉(zhuǎn)換,然后你可以在邊向量上做一些計(jì)算,如經(jīng)過一兩層的小神經(jīng)網(wǎng)絡(luò),得到邊上的消息向量。得到消息向量后,對destination node做sort和unique操作。聯(lián)想稀疏矩陣的乘法計(jì)算,類似上述的過程,可以分成兩步,第一步是在非零元素上進(jìn)行element-wise乘操作,第二步是在列上做加操作。
適合情形
當(dāng)|E|<<|v|*|v|時(shí)
優(yōu)化方案
稀疏計(jì)算的關(guān)鍵在于維護(hù)一個(gè)索引列表,能快速進(jìn)行sort、unique、join操作并調(diào)用如下深度學(xué)習(xí)庫函數(shù):
TensorFlow:
- gather, gather_ndm
- scatter_nd, segment_sum,
- segment_max, unsored_segment_sum|max
Pytorch:
思路八:稀疏routing
稀疏routing與partition不同,partition需要將整個(gè)圖都考慮進(jìn)來,而稀疏routing只需考慮大圖中所用到的局部子圖。單個(gè)樣本每次計(jì)算時(shí),只需要用到大圖的一個(gè)局部子圖,剛開始的子圖可能僅是一個(gè)節(jié)點(diǎn)或幾個(gè)節(jié)點(diǎn),即聚焦在一個(gè)很小的區(qū)域,計(jì)算過程中聚焦區(qū)域逐漸擴(kuò)大。這種routing的方式也是一種attention機(jī)制,與傳統(tǒng)的attention機(jī)制有所不同。傳統(tǒng)的attention用于匯總各方來的消息向量,采用加權(quán)平均的方式,讓incoming消息的權(quán)重相加等于1;對于routing的話,剛好相反,讓outgoing的邊權(quán)重和為1,這個(gè)有點(diǎn)類似PageRank算法。這樣做的好處,可以在計(jì)算過程中通過選取top-K的outgoing邊來構(gòu)建一個(gè)動態(tài)剪枝的子圖。
適合情形
全圖雖大,但每次僅用到局部子圖
優(yōu)化方案
Attention機(jī)制是“拉”的模式,routing機(jī)制是“推”的模式。
思路九:跨樣本共享的圖特征
當(dāng)你計(jì)算的圖特征(如節(jié)點(diǎn)向量)不依賴具體樣本時(shí),這些特征可以作為輸入喂給每個(gè)樣本,但是它們的大小不隨batch size的大小而增加。我們稱這些是input-agnostic features,由于跨樣本共享,它們相當(dāng)于batch size為1的輸入。
適合情形
提供input-agnostic features
優(yōu)化方案
跨樣本共享,相當(dāng)于batch size為1。
思路十:組合使用以上九種方法
組合使用以上九種方法,根據(jù)自己的實(shí)際情況設(shè)計(jì)適當(dāng)?shù)乃惴ā?/span>
免責(zé)聲明:本文內(nèi)容由21ic獲得授權(quán)后發(fā)布,版權(quán)歸原作者所有,本平臺僅提供信息存儲服務(wù)。文章僅代表作者個(gè)人觀點(diǎn),不代表本平臺立場,如有問題,請聯(lián)系我們,謝謝!