OpenAI提出新的神經(jīng)網(wǎng)絡(luò)模型“稀疏Transformer”,能夠預(yù)測(cè)文本、圖像和聲音等序列的后續(xù)內(nèi)容,該模型是對(duì)注意力機(jī)制的一個(gè)改進(jìn),預(yù)測(cè)長(zhǎng)度達(dá)到之前最佳水平的30倍。
目前人工智能研究的一大挑戰(zhàn)是對(duì)復(fù)雜數(shù)據(jù)(如圖像,視頻或聲音)中的大范圍微妙的相互依賴性進(jìn)行建模。稀疏Transformer降低了傳統(tǒng)注意力機(jī)制模型的計(jì)算復(fù)雜度,將其直接應(yīng)用于不同的數(shù)據(jù)類型中。以前,在這些數(shù)據(jù)上使用的模型是針對(duì)某個(gè)專門領(lǐng)域設(shè)計(jì)的,難以擴(kuò)展到超過(guò)幾千個(gè)元素的序列規(guī)模上應(yīng)用。
此次OpenAI提出的模型可以使用數(shù)百個(gè)層對(duì)數(shù)萬(wàn)個(gè)元素的序列進(jìn)行建模,在多個(gè)域中實(shí)現(xiàn)最先進(jìn)的性能。稀疏Transformer能夠幫助我們構(gòu)建具有更強(qiáng)的理解世界能力的AI系統(tǒng)。
深度注意力機(jī)制
在稀疏Transformer中,每個(gè)輸出元素都與每個(gè)輸入元素相連,它們之間的權(quán)重是根據(jù)環(huán)境動(dòng)態(tài)計(jì)算的,這個(gè)過(guò)程稱為注意力。雖然這樣會(huì)讓模型比固定連接模式的模型更加靈活,但在實(shí)踐中需要為每個(gè)層和注意力頭N×N注意力矩陣,面對(duì)元素?cái)?shù)量眾多的數(shù)據(jù)類型時(shí)會(huì)消耗大量的內(nèi)存,比如圖像或原始音頻數(shù)據(jù)。
當(dāng)矩陣存儲(chǔ)在內(nèi)存中或在后向傳遞期間重新計(jì)算時(shí),深度Transformer的內(nèi)存消耗情況(64層、4個(gè)注意力頭)。作為參考,用于深度學(xué)習(xí)的標(biāo)準(zhǔn)GPU通常配備12-32GB的內(nèi)存
減少內(nèi)存消耗一種方法是在反向傳播期間從檢查點(diǎn)重新計(jì)算注意力矩陣,這是深度學(xué)習(xí)中的一種成熟技術(shù),以增加計(jì)算量為代價(jià)來(lái)減少內(nèi)存使用。在計(jì)算Transformer的注意力矩陣時(shí),意味著最大的內(nèi)存成本與層數(shù)無(wú)關(guān),這使我們能夠以比以前更大的深度訓(xùn)練神經(jīng)網(wǎng)絡(luò)。
實(shí)際上,我們發(fā)現(xiàn)深度達(dá)128層的Transformer在常用數(shù)據(jù)集基準(zhǔn)任務(wù)(如CIFAR-10)上的表現(xiàn)優(yōu)于較淺層的網(wǎng)絡(luò)。
為了更深入地訓(xùn)練這些模型,我們對(duì)Transformer中的操作順序進(jìn)行了幾次調(diào)整,并修改了初始方案。
稀疏注意力機(jī)制:顯著降低計(jì)算復(fù)雜度
然而,即使是計(jì)算單個(gè)注意力矩陣,對(duì)于非常大的輸入也是不切實(shí)際。因此我們使用稀疏注意力模式,即每個(gè)輸出位置僅計(jì)算來(lái)自輸入位置子集的權(quán)重。當(dāng)子集相對(duì)于整個(gè)輸入集較小時(shí),即使對(duì)于非常長(zhǎng)的序列,所得到的注意力計(jì)算也是容易處理的,算法復(fù)雜度為O(N *sqrt {N}),而不是O(N^2)。
為了評(píng)估該方法的可行性,我們首先將深度Transformer在圖像上的學(xué)習(xí)注意模式進(jìn)行可視化,發(fā)現(xiàn)許多模型表現(xiàn)出可解釋和結(jié)構(gòu)化的稀疏模式。下面的每個(gè)圖像顯示給定的注意頭處理哪些輸入像素(以白色突出顯示)以便預(yù)測(cè)圖像中的下一個(gè)值。
當(dāng)輸入部分聚焦在小的子集上并顯示出高度的規(guī)則性時(shí),該層就是易于稀疏化的。下圖為CIFAR-10圖像上的128層模型示例。
左圖為19層,右圖為20層
學(xué)習(xí)后的128層CIFAR-10網(wǎng)絡(luò)的多個(gè)層的注意力模式(白色高亮部分)。這些層學(xué)會(huì)將注意力分散在兩個(gè)維度上。其中第19層總結(jié)了每一行的信息,第20層則按列聚合這些信息,從而能夠?qū)θ孀⒁饬Σ僮鬟M(jìn)行有效分解。
左圖為第6層,右圖為第36層
一些層學(xué)會(huì)了訪問(wèn)位置存儲(chǔ)器,無(wú)論輸入數(shù)據(jù)或時(shí)間步長(zhǎng)如何,通常都會(huì)訪問(wèn)類似的位置(第6層)。還有的層學(xué)習(xí)了高度依賴數(shù)據(jù)的訪問(wèn)模式(第36層)。
雖然許多圖層顯示出了稀疏結(jié)構(gòu),某些層還清晰地顯示出在整個(gè)圖像上延伸的動(dòng)態(tài)注意力。為了讓網(wǎng)絡(luò)保持學(xué)習(xí)這些模式的能力,我們進(jìn)行了注意力矩陣的二維分解,網(wǎng)絡(luò)可以通過(guò)兩個(gè)稀疏注意力步驟來(lái)關(guān)注所有位置。
(左)普通transformer,(中)范圍注意力,(右)固定注意力
第一個(gè)版本,大范圍注意力,大致相當(dāng)于參與其行和列的每個(gè)位置,并且類似于上面的網(wǎng)絡(luò)學(xué)習(xí)的注意力模式。(注意,列注意力可以等效地表示成轉(zhuǎn)置矩陣的行注意力)。第二個(gè)版本是固定注意力,注意固定列和最新列元素之后的元素,我們發(fā)現(xiàn)這種模式在數(shù)據(jù)不適合二維結(jié)構(gòu)(如文本)時(shí)很有用。
實(shí)驗(yàn)結(jié)果:創(chuàng)造多個(gè)數(shù)據(jù)集上的新紀(jì)錄
稀疏Transformer在CIFAR-10,Enwik8和Imagenet 64上創(chuàng)造了密度估計(jì)的最新記錄。如下表所示:
CIFAR-10 | BITS PER DIM |
PixelCNN++ (Oord et al, 2016) | 2.92 |
Image Transformer (Parmar et. al, 2018) | 2.90 |
PixelSNAIL (Chen et al., 2017) | 2.85 |
Sparse Transformer 59M (256W, 128L, 2H) | 2.80 |
ENWIK8 | BITS PER BYTE |
Deeper Self-Attention (Al-Rfou et al, 2018) | 1.06 |
Transformer-XL 88M (Dai et al., 2018) | 1.03 |
Transformer-XL 277M (Dai et al., 2018) | 0.99 |
Sparse Transformer 95M (512W, 30L, 8H) | 0.99 |
IMAGENET 64X64 | BITS PER DIM |
PixelCNN++ (Oord et al, 2016) | 3.57 |
Parallel Multiscale (Reed et al, 2017) | 3.7 |
SPN 150M (Menick & Kalchbrenner, 2018) | 3.52 |
Sparse Transformer 152M (512W, 48L, 16H) | 3.44 |
在一系列數(shù)據(jù)集上的密度建模表現(xiàn),M為網(wǎng)絡(luò)中使用的參數(shù)數(shù)量(百萬(wàn)),W為網(wǎng)絡(luò)寬度,L為層數(shù),H為注意力頭數(shù)量。
我們還發(fā)現(xiàn),除了速度明顯更快之外,稀疏注意力模型的損失也要低于完全注意力模型。這可能表明我們的稀疏模式存在有用的歸納偏差,或是密集關(guān)注的潛在優(yōu)化問(wèn)題。
使用稀疏注意力的Transformer似乎有一個(gè)全局結(jié)構(gòu)的概念,可以通過(guò)查看圖像完成來(lái)定性評(píng)估。我們對(duì)64×64 ImageNet上訓(xùn)練的模型進(jìn)行了可視化,如下圖所示:
Prompt
Completions
Ground truth
我們還利用未調(diào)整的softmax temperature 1.0下生成了完全無(wú)條件的樣圖。這些模型使用最大似然目標(biāo)進(jìn)行訓(xùn)練,眾所周知,這類訓(xùn)練的目標(biāo)是覆蓋所有數(shù)據(jù)模式(包括可能不存在的數(shù)據(jù)),而不是增加小部分?jǐn)?shù)據(jù)的保真度。從這些具有未調(diào)整溫度的模型中生成樣圖,可以讓我們看到模型認(rèn)為存在于真實(shí)世界中圖像的完整分布。結(jié)果,一些樣本看起來(lái)很奇怪。
模型采樣
真實(shí)數(shù)據(jù)
生成原始音頻波形
稀疏Transformer也可以通過(guò)簡(jiǎn)單地改變位置嵌入,自適應(yīng)地生成原始音頻。隨著深度學(xué)習(xí)擴(kuò)展到新型數(shù)據(jù)類型,可以使用這類網(wǎng)絡(luò)作為確定歸納偏差的有用工具。
該模型在原始古典音樂(lè)剪輯上進(jìn)行訓(xùn)練,并使用稀疏注意力生成長(zhǎng)度為65000的序列,相當(dāng)于大約5秒的原始音頻,我們?cè)诿總€(gè)片段中將幾個(gè)樣本連接在了一起。
關(guān)于代碼發(fā)布和開(kāi)源
通常,實(shí)現(xiàn)稀疏注意力將涉及在數(shù)據(jù)塊中將查詢和關(guān)鍵矩陣單獨(dú)“切片”,因此為了簡(jiǎn)化實(shí)驗(yàn),我們實(shí)現(xiàn)了一組塊稀疏內(nèi)核,這些內(nèi)核可以在GPU上高效執(zhí)行這些操作。我們開(kāi)源了這些內(nèi)核,并在Github上提供示例稀疏注意函數(shù)。
未來(lái)方向和局限
我們提出的稀疏注意力模式只是長(zhǎng)序列高效建模方向的初步模式。我們認(rèn)為,探索稀疏性的不同模式和組合的用途不僅于此,學(xué)習(xí)稀疏模式對(duì)于下一代神經(jīng)網(wǎng)絡(luò)體系結(jié)構(gòu)來(lái)說(shuō)是一個(gè)很有前途的方向。
即使經(jīng)過(guò)改進(jìn),自回歸序列生成對(duì)于非常高分辨率的圖像或視頻來(lái)說(shuō)仍然是不切實(shí)際的。不過(guò),我們提出的優(yōu)化注意力操作可能是一次有益的探索,可以和其他(如多尺度方法)方法相結(jié)合來(lái)對(duì)高維數(shù)據(jù)進(jìn)行建模。
-
神經(jīng)網(wǎng)絡(luò)
+關(guān)注
關(guān)注
42文章
4773瀏覽量
100877 -
圖像
+關(guān)注
關(guān)注
2文章
1086瀏覽量
40496 -
模型
+關(guān)注
關(guān)注
1文章
3255瀏覽量
48904
原文標(biāo)題:OpenAI提出Sparse Transformer,文本、圖像、聲音都能預(yù)測(cè),序列長(zhǎng)度提高30倍
文章出處:【微信號(hào):AI_era,微信公眾號(hào):新智元】歡迎添加關(guān)注!文章轉(zhuǎn)載請(qǐng)注明出處。
發(fā)布評(píng)論請(qǐng)先 登錄
相關(guān)推薦
評(píng)論