導(dǎo)讀
這項(xiàng)工作旨在提高視覺(jué)Transformer(ViT)的效率。雖然ViT在每一層中使用計(jì)算代價(jià)高昂的自注意力操作,但我們發(fā)現(xiàn)這些操作在層之間高度相關(guān)——這會(huì)導(dǎo)致產(chǎn)生很多不必要計(jì)算的冗余信息。基于這一觀察,我們提出了SKIPAT方法,該方法利用前面層的自注意力計(jì)算來(lái)近似在一個(gè)或多個(gè)后續(xù)層的注意力。為了確保在層之間重用自注意力塊而不降低性能,我們引入了一個(gè)簡(jiǎn)單的參數(shù)函數(shù),該函數(shù)在計(jì)算速度更快的情況下能表現(xiàn)出優(yōu)于基準(zhǔn)Transformer的性能。我們?cè)趫D像分類和ImageNet-1K上的自我監(jiān)督學(xué)習(xí)、ADE20K上的語(yǔ)義分割、SIDD上的圖像去噪以及DAVIS上的視頻去噪中展示了我們方法的有效性。我們?cè)谒羞@些任務(wù)中都在相同或更高的準(zhǔn)確度水平下實(shí)現(xiàn)了提高模型吞吐量。
背景
Performance of SKIPAT across 5 different tasks.
Transformer架構(gòu)已經(jīng)成為一個(gè)重要且影響深遠(yuǎn)的模型系列,因?yàn)樗?jiǎn)單、可擴(kuò)展,并且應(yīng)用廣泛。雖然最初來(lái)自自然語(yǔ)言處理(NLP)領(lǐng)域,但隨著視覺(jué)transformer(ViT)的出現(xiàn),這已成為計(jì)算機(jī)視覺(jué)領(lǐng)域的標(biāo)準(zhǔn)架構(gòu),在從表示學(xué)習(xí)、語(yǔ)義分割、目標(biāo)檢測(cè)到視頻理解等任務(wù)中獲得了各種最先進(jìn)(SoTA)性能。
然而,transformer的原始公式在輸入令牌(token)數(shù)量方面具有二次計(jì)算復(fù)雜度。鑒于這個(gè)數(shù)字通常從圖像分類的14^2到圖像去噪的128^2 = 16K不等,內(nèi)存和計(jì)算的這一限制嚴(yán)重限制了它的適用性。目前有三組方法來(lái)解決這個(gè)問(wèn)題:第一組利用輸入令牌之間的冗余,并通過(guò)高效的抽樣簡(jiǎn)單地減少計(jì)算,例如丟棄或合并冗余令牌。然而,這意味著ViT的最終輸出不是空間連續(xù)的,因此不能超出圖像級(jí)別(image-level)的應(yīng)用,如語(yǔ)義分割或目標(biāo)檢測(cè)。第二組方法旨在以低成本計(jì)算近似注意力,但通常以性能降低為代價(jià)。最后,另一組工作旨在將卷積架構(gòu)與transformer合并,產(chǎn)生混合架構(gòu)。雖然這些方法提高了速度,但它們并沒(méi)有解決二次復(fù)雜度的基本問(wèn)題,并且通常會(huì)引入過(guò)多的設(shè)計(jì)選擇(基本上是transformer和CNN的聯(lián)合)。
在這項(xiàng)工作中,我們提出了一種新穎的、迄今為止未經(jīng)探索的方法:利用計(jì)算速度快且簡(jiǎn)單的參數(shù)函數(shù)來(lái)逼近transformer的計(jì)算代價(jià)高的塊。為了得出這個(gè)解決方案,我們首先詳細(xì)地分析了ViT的關(guān)鍵多頭自注意力(MSA)塊。通過(guò)這項(xiàng)分析,我們發(fā)現(xiàn)CLS令牌對(duì)空間塊的注意力在transformer的塊之間具有非常高的相關(guān)性,從而導(dǎo)致許多不必要的計(jì)算。這啟發(fā)了我們的方法利用模型早期的注意力,并將其簡(jiǎn)單地重用于更深的塊——基本上是“跳過(guò)”后續(xù)的SA計(jì)算,而不是在每一層重新計(jì)算它們。
基于此,我們進(jìn)一步探索是否可以通過(guò)重用前面層的表示來(lái)跳過(guò)整一層的MSA塊。受ResneXt的深度卷積的啟發(fā),我們發(fā)現(xiàn)一個(gè)簡(jiǎn)單的參數(shù)函數(shù)可以優(yōu)于基準(zhǔn)模型性能——在吞吐量和FLOPs的計(jì)算速度方面更快。我們的方法是通用的,可以應(yīng)用于任何上下文的ViT:上圖顯示,我們的跳過(guò)注意力(SKIPAT)的新型參數(shù)函數(shù)在各種任務(wù)、數(shù)據(jù)集和模型大小上都能實(shí)現(xiàn)與基準(zhǔn)transformer相比更優(yōu)的精度與效率。
綜上所述,我們的貢獻(xiàn)如下所示:
我們提出了一種新型的插件模塊,可以放在任何ViT架構(gòu)中,以減少昂貴的O(n^2)自注意力計(jì)算復(fù)雜度。
我們?cè)贗mageNet、Pascal-VOC2012、SIDD、DAVIS和ADE20K數(shù)據(jù)集上實(shí)現(xiàn)了在吞吐量指標(biāo)上的最SOTA性能,并獲得了同等或更高的準(zhǔn)確度。
我們的方法在沒(méi)有下游準(zhǔn)確度損失的情況下,自監(jiān)督預(yù)訓(xùn)練時(shí)間能減少26%,并且在移動(dòng)設(shè)備上展示了優(yōu)越的延遲,這都證明了我們方法的普適性。
我們分析了性能提升的來(lái)源,并對(duì)我們的方法進(jìn)行了大量的實(shí)驗(yàn)分析,為提供可用于權(quán)衡準(zhǔn)確度和吞吐量的模型系列提供了支持。
方法
SKIPAT framework.
引言
Vision Transformer
設(shè)x ∈ R^(h×w×c) 為一張輸入圖像,其中h × w是空間分辨率,c是通道數(shù)。首先將圖像分成n = hw/p^2個(gè)不重疊的塊,其中p × p是塊大小。使用線性層將每個(gè)塊投影到一個(gè)embedding zi ∈ R^d 中,從而得到分塊的圖像:
Transformer Layer
Transformer的每一層由多頭自注意力(MSA)塊和多層感知機(jī)(MLP)塊組成。在MSA塊中,Zl?1 ∈ R^(n×d),首先被投影到三個(gè)可學(xué)習(xí)embeddings {Q, K, V } ∈ R^(n×d)中。注意力矩陣A的計(jì)算公式如下:
MSA中的“多頭”是指考慮h個(gè)注意力頭,其中每個(gè)頭是一個(gè)n × d/h 矩陣的序列。使用線性層將注意頭重新投影回n × d,并與值矩陣結(jié)合,公式如下所示:
然后,將MSA塊的輸出表示輸入到MLP塊,該塊包括兩個(gè)由GeLU激活分隔的線性層。在給定層l處,表示通過(guò)transformer塊的計(jì)算流程如下:
MSA和MLP塊都具有帶層正則化(LN)的殘差連接。雖然transformer的每一層中的MSA塊均是學(xué)習(xí)互不依賴的表示,但在下一小節(jié)中,我們將展示這些跨層間存在高度相關(guān)性。
啟發(fā): 層相關(guān)性分析
Attention correlation.
ViT中的MSA塊將每個(gè)塊與每個(gè)其他塊的相似性編碼為n × n注意力矩陣。這個(gè)運(yùn)算符具有O(n^2)復(fù)雜度(公式2)的計(jì)算成本。隨著ViT的擴(kuò)展,即隨著n的增加,計(jì)算復(fù)雜度呈二次增長(zhǎng),使得這個(gè)操作成為性能瓶頸。最近的NLP工作表明,SoTA語(yǔ)言模型中相鄰層之間的自注意力具有非常高的相關(guān)性。這引發(fā)了一個(gè)問(wèn)題 -在視覺(jué)transformer是否真的需要每一層都計(jì)算自注意力?
CKA analysis of A^[CLS] and Z^MSA across different layers of pretrained ViT-T/16.
為了回答這個(gè)問(wèn)題,我們分析了ViT不同層之間自注意力圖的相關(guān)性。如本節(jié)圖1所示,來(lái)自類別token的自注意力圖A^[CLS]在中間層特別具有高度相關(guān)性。A^[CLS]l?1和A^[CLS]l 之間的余弦相似度可以高達(dá)0.97。其他token embeddings 也表現(xiàn)出類似的行為。我們通過(guò)計(jì)算每對(duì)i,j∈L的A^[CLS]i和A^[CLS]j之間的Centered Kernel Alignment(CKA)來(lái)定量分析ImageNet-1K驗(yàn)證集的所有樣本之間的相關(guān)性。CKA度量網(wǎng)絡(luò)中間層獲得的表示之間的相似性,其中CKA的值越高則表示它們之間的相關(guān)性越高。從本節(jié)圖2中,我們發(fā)現(xiàn)ViT-T在A^[CLS]之間具有高度性,特別是第三層到第十層。
Feature correlation
在ViT中,高相關(guān)性不僅局限于A^[CLS],MSA塊的表示Z^MSA也在整個(gè)模型中顯示出高度相關(guān)性。為了分析這些表示之間的相似性,我們計(jì)算每對(duì)i,j∈L的Z^MSAi和Z^MSAj之間的CKA。我們從從本節(jié)圖2中觀察到,Z^MSA在模型的相鄰層之間也具有很高的相似性,特別是在較早的層,即從第2層到第8層。
利用 Skipping Attention 提升效率
基于我們對(duì)transformer中MSA不同塊之間具有高度相似性的觀察,我們建議利用注意力矩陣和MSA塊的表示之間的相關(guān)性來(lái)提高視覺(jué)transformer的效率。與在每層單獨(dú)計(jì)算MSA操作(公式3)相反,我們探索了一種利用不同層之間依賴關(guān)系的簡(jiǎn)單且有效的策略。
我們建議通過(guò)重用其相鄰層的特征表示來(lái)跳過(guò)transformer的一個(gè)或多個(gè)層中的MSA計(jì)算。我們將此操作稱為Skip Attention(SKIPAT)。由于跳過(guò)整個(gè)MSA塊的計(jì)算和內(nèi)存效益大于僅跳過(guò)自注意力操作 O(n^2d+nd^2) vs. O(n^2d),因此在本文中我們主要關(guān)注前者。我們引入了一個(gè)參數(shù)函數(shù),而不是直接重用特征,換句話說(shuō),就是將來(lái)源MSA塊的特征復(fù)制到一個(gè)或多個(gè)相鄰MSA塊。參數(shù)函數(shù)確保直接重用特征不會(huì)影響這些MSA塊中的平移不變性和等價(jià)性,并充當(dāng)強(qiáng)大的正則化器以提高模型泛化性。
SKIPAT parametric function
設(shè) Φ:R^(n×d) → R^(n×d)表示將l?1層的MSA塊映射到l層的參數(shù)函數(shù),作為Z?^MSA l:=Φ(Z^MSA l?1)。在這里,Z?^MSA l是Z^MSA l的近似值。參數(shù)函數(shù)可以是簡(jiǎn)單的單位函數(shù),其中Z^MSA l?1能被直接重用。我們使用Z^MSA l?1作為l處的MLP塊的輸入,而不是在l處計(jì)算MSA操作。當(dāng)使用單位函數(shù)時(shí),由于l處沒(méi)有MSA操作,因此在注意力矩陣中的token間關(guān)系不再被編碼,這會(huì)影響表示學(xué)習(xí)。為了減輕這一點(diǎn),我們引入了SKIPAT參數(shù)函數(shù),用于對(duì)token之間的局部關(guān)系進(jìn)行編碼。SKIPAT參數(shù)函數(shù)由兩個(gè)線性層和中間的深度卷積(DwC)組成,計(jì)算公式如下所示:
SKIPAT framework
SKIPAT 是一種可以被納入任何 transformer 架構(gòu)的框架,我們通過(guò)大量實(shí)驗(yàn)對(duì)比結(jié)果充分地證明了這一點(diǎn)。根據(jù)架構(gòu)的不同,可以在 transformer 的一層或多層中跳過(guò) MSA 操作。在 ViT 中,我們觀察到來(lái)自 MSA 塊(Z^MSA )的表示在第 2 層到第 7 層之間有很高的相關(guān)性,所以我們?cè)谶@些層中使用 SKIPAT 參數(shù)函數(shù)。這意味著我們將 Z^MSA2 作為輸入傳遞給 SKIPAT 參數(shù)函數(shù),并在 3-8 層中跳過(guò) MSA 操作。相反,來(lái)自 SKIPAT 參數(shù)函數(shù)輸出的特征被用作 MLP 塊的輸入。表示的計(jì)算流現(xiàn)在被修改為:
由于 MSA 和 MLP 塊中存在殘留連接,第 3 層到第 8 層的 MLP 塊需要獨(dú)立地學(xué)習(xí)表示,不能從計(jì)算圖中刪除。值得注意的是,使用 SKIPAT 后 ViT 的總層數(shù)不變,但 MSA 塊的數(shù)量減少了。
Complexity: MSA vs. SKIPAT
自注意力操作包括三個(gè)步驟。首先,將token embeddings 投射到query、key和value embeddings,其次,計(jì)算注意力矩陣 A,它是 Q 和 K 的點(diǎn)積,最后,計(jì)算輸出表示作為 A 和 V 的點(diǎn)積。這導(dǎo)致了計(jì)算復(fù)雜度為 O(4nd^2 + n^2d)。由于 d ? n,所以 MSA 塊的復(fù)雜度可以降低到 O(n^2d)。
SKIPAT 參數(shù)函數(shù)由兩個(gè)線性層和一個(gè)深度卷積操作組成,計(jì)算復(fù)雜度為 O(2nd^2 + r^2nd),其中 r × r 是 DwC 操作的內(nèi)核大小。由于 r^2 ? d,所以 SKIPAT 的整體復(fù)雜度可以降低到 O(nd^2)。因此,當(dāng) n 隨著 transformer 的擴(kuò)大而增加時(shí),SKIPAT 的 FLOPs值 比 MSA 塊更少,即 O(nd^2) < O(n^2d)。
實(shí)驗(yàn)
上圖展示的是分割mask的可視化效果:第一行和第二行分別是原始Vit-S模型和Vit-S + SKIPAT模型。顯而易見(jiàn),Vit-S + SKIPAT模型對(duì)圖像中前景和背景的區(qū)分度顯著高于原始Vit-S模型。
上圖展示的是注意力圖的可視化效果:對(duì)比原始Vit-S模型(baseline),Vit-S + SKIPAT模型對(duì)目標(biāo)的定位能力有明顯提升。
上圖展示的是特征圖和Z^MSA的相關(guān)性:從中可以清晰地觀察到在大多數(shù)不同層之間Z^MSA僅有較低的相關(guān)性。
圖象分類
Image classification on ImageNet-1K.
自監(jiān)督
Unsupervised Segmentation and Object Localization on the validation set of Pascal VOC2012.
推理性能
On-device latency (in msec) of vanilla ViT vs. SKIPAT.
語(yǔ)義分割
Semantic Segmentation results on ADE20K.
圖像去噪
Image denoising on SIDD dataset using PSNR andSSIM as the evaluation metrics in the RGB space.
總結(jié)
我們提出了一種可以在任何 ViT 架構(gòu)中即插即用的模塊 SKIPAT,用于減少昂貴的自注意力計(jì)算。SKIPAT 利用 MSA 塊之間的依賴性,并通過(guò)重用以前 MSA 塊的注意力表示來(lái)繞過(guò)注意力計(jì)算。此外,我們引入了一個(gè)簡(jiǎn)單且輕量的參數(shù)函數(shù),它不會(huì)影響 MSA 中編碼的歸納偏見(jiàn)。SKIPAT 函數(shù)能夠捕獲跨token之間的關(guān)系,在吞吐量和 FLOPs 指標(biāo)上優(yōu)于基線模型,同時(shí)我們?cè)? 種不同的任務(wù)中充分地表現(xiàn)出SKIPAT的有效性。
審核編輯 :李倩
-
編碼
+關(guān)注
關(guān)注
6文章
962瀏覽量
55307 -
Transformer
+關(guān)注
關(guān)注
0文章
147瀏覽量
6301 -
自然語(yǔ)言處理
+關(guān)注
關(guān)注
1文章
625瀏覽量
13904
原文標(biāo)題:即插即用!Skip-Attention:一種顯著降低Transformer計(jì)算量的輕量化方法
文章出處:【微信號(hào):CVer,微信公眾號(hào):CVer】歡迎添加關(guān)注!文章轉(zhuǎn)載請(qǐng)注明出處。
發(fā)布評(píng)論請(qǐng)先 登錄
相關(guān)推薦
知行科技將與智駕大陸攜手打造輕量化城區(qū)領(lǐng)航產(chǎn)品
引領(lǐng)輕量化趨勢(shì)| 法法易輕量化充電槍通過(guò)2023版標(biāo)準(zhǔn)強(qiáng)檢測(cè)試

守護(hù)公路安全! 中海達(dá)推出輕量化監(jiān)測(cè)簡(jiǎn)易感知方案

一種信息引導(dǎo)的量化后LLM微調(diào)新算法IR-QLoRA

中海達(dá)推出輕量化監(jiān)測(cè)簡(jiǎn)易感知解決方案
自動(dòng)駕駛中一直說(shuō)的BEV+Transformer到底是個(gè)啥?

5G RedCap:輕量化5G技術(shù)引領(lǐng)物聯(lián)網(wǎng)新未來(lái)

5G輕量化網(wǎng)關(guān)是什么

輕量化IP制作與傳輸?shù)淖兏镏?千視Judy專訪

深度神經(jīng)網(wǎng)絡(luò)模型量化的基本方法
深度學(xué)習(xí)模型量化方法

評(píng)論