深度神經(jīng)網(wǎng)絡(luò)模型訓(xùn)練之難眾所周知,其中一個(gè)重要的現(xiàn)象就是 Internal Covariate Shift. Batch Norm 大法自 2015 年由Google 提出之后,就成為深度學(xué)習(xí)必備之神器。自 BN 之后, Layer Norm / Weight Norm / Cosine Norm 等也橫空出世。本文從 Normalization 的背景講起,用一個(gè)公式概括 Normalization 的基本思想與通用框架,將各大主流方法一一對號(hào)入座進(jìn)行深入的對比分析,并從參數(shù)和數(shù)據(jù)的伸縮不變性的角度探討 Normalization 有效的深層原因。
作者:Juliuszh,PhD 畢業(yè)于 THU 計(jì)算機(jī)系。現(xiàn)在 Tencent AI Lab 從事機(jī)器學(xué)習(xí)和個(gè)性化推薦研究與 AI 平臺(tái)開發(fā)工作。
來源:機(jī)器學(xué)習(xí)札記知乎專欄
目錄:
1. 為什么需要 Normalization
——深度學(xué)習(xí)中的 Internal Covariate Shift 問題及其影響
2. Normalization 的通用框架與基本思想
——從主流 Normalization 方法中提煉出的抽象框架
3. 主流 Normalization 方法梳理
——結(jié)合上述框架,將 BatchNorm / LayerNorm / WeightNorm / CosineNorm 對號(hào)入座,各種方法之間的異同水落石出。
4. Normalization 為什么會(huì)有效?
——從參數(shù)和數(shù)據(jù)的伸縮不變性探討Normalization有效的深層原因。
以下是正文,enjoy.
1. 為什么需要 Normalization
1.1 獨(dú)立同分布與白化
機(jī)器學(xué)習(xí)界的煉丹師們最喜歡的數(shù)據(jù)有什么特點(diǎn)?竊以為,莫過于“獨(dú)立同分布”了,即independent and identically distributed,簡稱為 i.i.d. 獨(dú)立同分布并非所有機(jī)器學(xué)習(xí)模型的必然要求(比如 Naive Bayes 模型就建立在特征彼此獨(dú)立的基礎(chǔ)之上,而Logistic Regression 和 神經(jīng)網(wǎng)絡(luò) 則在非獨(dú)立的特征數(shù)據(jù)上依然可以訓(xùn)練出很好的模型),但獨(dú)立同分布的數(shù)據(jù)可以簡化常規(guī)機(jī)器學(xué)習(xí)模型的訓(xùn)練、提升機(jī)器學(xué)習(xí)模型的預(yù)測能力,已經(jīng)是一個(gè)共識(shí)。
因此,在把數(shù)據(jù)喂給機(jī)器學(xué)習(xí)模型之前,“白化(whitening)”是一個(gè)重要的數(shù)據(jù)預(yù)處理步驟。白化一般包含兩個(gè)目的:
(1)去除特征之間的相關(guān)性 —> 獨(dú)立;
(2)使得所有特征具有相同的均值和方差 —> 同分布。
白化最典型的方法就是PCA,可以參考閱讀 PCAWhitening。
1.2 深度學(xué)習(xí)中的 Internal Covariate Shift
深度神經(jīng)網(wǎng)絡(luò)模型的訓(xùn)練為什么會(huì)很困難?其中一個(gè)重要的原因是,深度神經(jīng)網(wǎng)絡(luò)涉及到很多層的疊加,而每一層的參數(shù)更新會(huì)導(dǎo)致上層的輸入數(shù)據(jù)分布發(fā)生變化,通過層層疊加,高層的輸入分布變化會(huì)非常劇烈,這就使得高層需要不斷去重新適應(yīng)底層的參數(shù)更新。為了訓(xùn)好模型,我們需要非常謹(jǐn)慎地去設(shè)定學(xué)習(xí)率、初始化權(quán)重、以及盡可能細(xì)致的參數(shù)更新策略。
Google 將這一現(xiàn)象總結(jié)為 Internal Covariate Shift,簡稱 ICS. 什么是 ICS 呢?@魏秀參在一個(gè)回答中做出了一個(gè)很好的解釋:
大家都知道在統(tǒng)計(jì)機(jī)器學(xué)習(xí)中的一個(gè)經(jīng)典假設(shè)是“源空間(source domain)和目標(biāo)空間(target domain)的數(shù)據(jù)分布(distribution)是一致的”。如果不一致,那么就出現(xiàn)了新的機(jī)器學(xué)習(xí)問題,如 transfer learning / domain adaptation 等。而 covariate shift 就是分布不一致假設(shè)之下的一個(gè)分支問題,它是指源空間和目標(biāo)空間的條件概率是一致的,但是其邊緣概率不同,即:
但是大家細(xì)想便會(huì)發(fā)現(xiàn),的確,對于神經(jīng)網(wǎng)絡(luò)的各層輸出,由于它們經(jīng)過了層內(nèi)操作作用,其分布顯然與各層對應(yīng)的輸入信號(hào)分布不同,而且差異會(huì)隨著網(wǎng)絡(luò)深度增大而增大,可是它們所能“指示”的樣本標(biāo)記(label)仍然是不變的,這便符合了covariate shift的定義。由于是對層間信號(hào)的分析,也即是“internal”的來由。
1.3 ICS 會(huì)導(dǎo)致什么問題?
簡而言之,每個(gè)神經(jīng)元的輸入數(shù)據(jù)不再是“獨(dú)立同分布”。
其一,上層參數(shù)需要不斷適應(yīng)新的輸入數(shù)據(jù)分布,降低學(xué)習(xí)速度。
其二,下層輸入的變化可能趨向于變大或者變小,導(dǎo)致上層落入飽和區(qū),使得學(xué)習(xí)過早停止。
其三,每層的更新都會(huì)影響到其它層,因此每層的參數(shù)更新策略需要盡可能的謹(jǐn)慎。
2. Normalization 的通用框架與基本思想
我們以神經(jīng)網(wǎng)絡(luò)中的一個(gè)普通神經(jīng)元為例。神經(jīng)元接收一組輸入向量
$${x}=(x_1, x_2, /cdots, x_d)$$
通過某種運(yùn)算后,輸出一個(gè)標(biāo)量值:
$$y=f({x})$$
由于 ICS 問題的存在,x的分布可能相差很大。要解決獨(dú)立同分布的問題,“理論正確”的方法就是對每一層的數(shù)據(jù)都進(jìn)行白化操作。然而標(biāo)準(zhǔn)的白化操作代價(jià)高昂,特別是我們還希望白化操作是可微的,保證白化操作可以通過反向傳播來更新梯度。
因此,以 BN 為代表的 Normalization 方法退而求其次,進(jìn)行了簡化的白化操作。基本思想是:在將x送給神經(jīng)元之前,先對其做平移和伸縮變換, 將x的分布規(guī)范化成在固定區(qū)間范圍的標(biāo)準(zhǔn)分布。
通用變換框架就如下所示:
我們來看看這個(gè)公式中的各個(gè)參數(shù)。
奇不奇怪?奇不奇怪?
說好的處理 ICS,第一步都已經(jīng)得到了標(biāo)準(zhǔn)分布,第二步怎么又給變走了?
答案是——為了保證模型的表達(dá)能力不因?yàn)橐?guī)范化而下降。
我們可以看到,第一步的變換將輸入數(shù)據(jù)限制到了一個(gè)全局統(tǒng)一的確定范圍(均值為 0、方差為 1)。下層神經(jīng)元可能很努力地在學(xué)習(xí),但不論其如何變化,其輸出的結(jié)果在交給上層神經(jīng)元進(jìn)行處理之前,將被粗暴地重新調(diào)整到這一固定范圍。
沮不沮喪?沮不沮喪?
難道我們底層神經(jīng)元人民就在做無用功嗎?
所以,為了尊重底層神經(jīng)網(wǎng)絡(luò)的學(xué)習(xí)結(jié)果,我們將規(guī)范化后的數(shù)據(jù)進(jìn)行再平移和再縮放,使得每個(gè)神經(jīng)元對應(yīng)的輸入范圍是針對該神經(jīng)元量身定制的一個(gè)確定范圍(均值為b、方差為g2)。rescale 和 reshift 的參數(shù)都是可學(xué)習(xí)的,這就使得 Normalization 層可以學(xué)習(xí)如何去尊重底層的學(xué)習(xí)結(jié)果。
除了充分利用底層學(xué)習(xí)的能力,另一方面的重要意義在于保證獲得非線性的表達(dá)能力。Sigmoid 等激活函數(shù)在神經(jīng)網(wǎng)絡(luò)中有著重要作用,通過區(qū)分飽和區(qū)和非飽和區(qū),使得神經(jīng)網(wǎng)絡(luò)的數(shù)據(jù)變換具有了非線性計(jì)算能力。而第一步的規(guī)范化會(huì)將幾乎所有數(shù)據(jù)映射到激活函數(shù)的非飽和區(qū)(線性區(qū)),僅利用到了線性變化能力,從而降低了神經(jīng)網(wǎng)絡(luò)的表達(dá)能力。而進(jìn)行再變換,則可以將數(shù)據(jù)從線性區(qū)變換到非線性區(qū),恢復(fù)模型的表達(dá)能力。
那么問題又來了——
經(jīng)過這么的變回來再變過去,會(huì)不會(huì)跟沒變一樣?
那么還有一個(gè)問題——
這樣的 Normalization 離標(biāo)準(zhǔn)的白化還有多遠(yuǎn)?
標(biāo)準(zhǔn)白化操作的目的是“獨(dú)立同分布”。獨(dú)立就不說了,暫不考慮。變換為均值為b、方差為g2(g的平方)的分布,也并不是嚴(yán)格的同分布,只是映射到了一個(gè)確定的區(qū)間范圍而已。(所以,這個(gè)坑還有得研究呢!)
3. 主流 Normalization 方法梳理
在上一節(jié)中,我們提煉了 Normalization 的通用公式:
對照于這一公式,我們來梳理主流的四種規(guī)范化方法。
3.1 Batch Normalization —— 縱向規(guī)范化
Batch Normalization 于2015年由 Google 提出,開 Normalization 之先河。其規(guī)范化針對單個(gè)神經(jīng)元進(jìn)行,利用網(wǎng)絡(luò)訓(xùn)練時(shí)一個(gè) mini-batch 的數(shù)據(jù)來計(jì)算該神經(jīng)元Xi的均值和方差,因而稱為 Batch Normalization。
其中M是 mini-batch 的大小。
按上圖所示,相對于一層神經(jīng)元的水平排列,BN 可以看做一種縱向的規(guī)范化。由于 BN 是針對單個(gè)維度定義的,因此標(biāo)準(zhǔn)公式中的計(jì)算均為 element-wise 的。
BN 獨(dú)立地規(guī)范化每一個(gè)輸入維度Xi,但規(guī)范化的參數(shù)是一個(gè) mini-batch 的一階統(tǒng)計(jì)量和二階統(tǒng)計(jì)量。這就要求 每一個(gè) mini-batch 的統(tǒng)計(jì)量是整體統(tǒng)計(jì)量的近似估計(jì),或者說每一個(gè) mini-batch 彼此之間,以及和整體數(shù)據(jù),都應(yīng)該是近似同分布的。分布差距較小的 mini-batch 可以看做是為規(guī)范化操作和模型訓(xùn)練引入了噪聲,可以增加模型的魯棒性;但如果每個(gè) mini-batch的原始分布差別很大,那么不同 mini-batch 的數(shù)據(jù)將會(huì)進(jìn)行不一樣的數(shù)據(jù)變換,這就增加了模型訓(xùn)練的難度。
因此,BN 比較適用的場景是:每個(gè) mini-batch 比較大,數(shù)據(jù)分布比較接近。在進(jìn)行訓(xùn)練之前,要做好充分的 shuffle. 否則效果會(huì)差很多。
另外,由于 BN 需要在運(yùn)行過程中統(tǒng)計(jì)每個(gè) mini-batch 的一階統(tǒng)計(jì)量和二階統(tǒng)計(jì)量,因此不適用于 動(dòng)態(tài)的網(wǎng)絡(luò)結(jié)構(gòu) 和 RNN 網(wǎng)絡(luò)。不過,也有研究者專門提出了適用于 RNN 的 BN 使用方法,這里先不展開了。
3.2 Layer Normalization —— 橫向規(guī)范化
層規(guī)范化就是針對 BN 的上述不足而提出的。與 BN 不同,LN 是一種橫向的規(guī)范化,如圖所示。它綜合考慮一層所有維度的輸入,計(jì)算該層的平均輸入值和輸入方差,然后用同一個(gè)規(guī)范化操作來轉(zhuǎn)換各個(gè)維度的輸入。
LN 針對單個(gè)訓(xùn)練樣本進(jìn)行,不依賴于其他數(shù)據(jù),因此可以避免 BN 中受 mini-batch 數(shù)據(jù)分布影響的問題,可以用于 小mini-batch場景、動(dòng)態(tài)網(wǎng)絡(luò)場景和 RNN,特別是自然語言處理領(lǐng)域。此外,LN 不需要保存 mini-batch 的均值和方差,節(jié)省了額外的存儲(chǔ)空間。
但是,BN 的轉(zhuǎn)換是針對單個(gè)神經(jīng)元可訓(xùn)練的——不同神經(jīng)元的輸入經(jīng)過再平移和再縮放后分布在不同的區(qū)間,而 LN 對于一整層的神經(jīng)元訓(xùn)練得到同一個(gè)轉(zhuǎn)換——所有的輸入都在同一個(gè)區(qū)間范圍內(nèi)。如果不同輸入特征不屬于相似的類別(比如顏色和大小),那么 LN 的處理可能會(huì)降低模型的表達(dá)能力。
3.3 Weight Normalization —— 參數(shù)規(guī)范化
BN 和 LN 均將規(guī)范化應(yīng)用于輸入的特征數(shù)據(jù)x,而 WN 則另辟蹊徑,將規(guī)范化應(yīng)用于線性變換函數(shù)的權(quán)重 w,這就是 WN 名稱的來源。
乍一看,這一方法似乎脫離了我們前文所講的通用框架?
并沒有。其實(shí)從最終實(shí)現(xiàn)的效果來看,異曲同工。我們來推導(dǎo)一下看。
對照一下前述框架:
我們只需令:
就完美地對號(hào)入座了!
回憶一下,BN 和 LN 是用輸入的特征數(shù)據(jù)的方差對輸入數(shù)據(jù)進(jìn)行 scale,而 WN 則是用 神經(jīng)元的權(quán)重的歐氏范式對輸入數(shù)據(jù)進(jìn)行 scale。雖然在原始方法中分別進(jìn)行的是特征數(shù)據(jù)規(guī)范化和參數(shù)的規(guī)范化,但本質(zhì)上都實(shí)現(xiàn)了對數(shù)據(jù)的規(guī)范化,只是用于 scale 的參數(shù)來源不同。
另外,我們看到這里的規(guī)范化只是對數(shù)據(jù)進(jìn)行了 scale,而沒有進(jìn)行 shift,因?yàn)槲覀兒唵蔚亓?u=0. 但事實(shí)上,這里留下了與 BN 或者 LN 相結(jié)合的余地——那就是利用 BN 或者 LN 的方法來計(jì)算輸入數(shù)據(jù)的均值 u 。
WN 的規(guī)范化不直接使用輸入數(shù)據(jù)的統(tǒng)計(jì)量,因此避免了 BN 過于依賴 mini-batch 的不足,以及 LN 每層唯一轉(zhuǎn)換器的限制,同時(shí)也可以用于動(dòng)態(tài)網(wǎng)絡(luò)結(jié)構(gòu)。
3.4 Cosine Normalization —— 余弦規(guī)范化
Normalization 還能怎么做?
我們再來看看神經(jīng)元的經(jīng)典變換
$$f_{w}({x})={w}/cdot{x}$$
對輸入數(shù)據(jù)x的變換已經(jīng)做過了,橫著來是 LN,縱著來是 BN。
對模型參數(shù)w的變換也已經(jīng)做過了,就是 WN。
好像沒啥可做的了。
然而天才的研究員們盯上了中間的那個(gè)點(diǎn),對,就是 . 。
他們說,我們要對數(shù)據(jù)進(jìn)行規(guī)范化的原因,是數(shù)據(jù)經(jīng)過神經(jīng)網(wǎng)絡(luò)的計(jì)算之后可能會(huì)變得很大,導(dǎo)致數(shù)據(jù)分布的方差爆炸,而這一問題的根源就是我們的計(jì)算方式——點(diǎn)積,權(quán)重向量w和 特征數(shù)據(jù)向量x的點(diǎn)積。向量點(diǎn)積是無界(unbounded)的啊!
那怎么辦呢?我們知道向量點(diǎn)積是衡量兩個(gè)向量相似度的方法之一。哪還有沒有其他的相似度衡量方法呢?有啊,很多啊!夾角余弦就是其中之一啊!而且關(guān)鍵的是,夾角余弦是有確定界的啊,[-1, 1] 的取值范圍,多么的美好!仿佛看到了新的世界!
于是,Cosine Normalization 就出世了。他們不處理權(quán)重向量w,也不處理特征數(shù)據(jù)向量x ,就改了一下線性變換的函數(shù):
然后就沒有然后了,所有的數(shù)據(jù)就都是 [-1, 1] 區(qū)間范圍之內(nèi)的了!
不過,回過頭來看,CN 與 WN 還是很相似的。我們看到上式中,分子還是 w和x的內(nèi)積,而分母則可以看做用w和 x二者的模之積進(jìn)行規(guī)范化。對比一下 WN 的公式:
CN 通過用余弦計(jì)算代替內(nèi)積計(jì)算實(shí)現(xiàn)了規(guī)范化,但成也蕭何敗蕭何。原始的內(nèi)積計(jì)算,其幾何意義是 輸入向量在權(quán)重向量上的投影,既包含 二者的夾角信息,也包含 兩個(gè)向量的scale信息。去掉scale信息,可能導(dǎo)致表達(dá)能力的下降,因此也引起了一些爭議和討論。具體效果如何,可能需要在特定的場景下深入實(shí)驗(yàn)。
現(xiàn)在,BN, LN, WN 和 CN 之間的來龍去脈是不是清楚多了?
4. Normalization 為什么會(huì)有效?
我們以下面這個(gè)簡化的神經(jīng)網(wǎng)絡(luò)為例來分析。
4.1 Normalization 的權(quán)重伸縮不變性
上述規(guī)范化方法均有這一性質(zhì),這是因?yàn)椋?dāng)權(quán)重w伸縮時(shí),對應(yīng)的均值和標(biāo)準(zhǔn)差均等比例伸縮,分子分母相抵。
權(quán)重伸縮不變性可以有效地提高反向傳播的效率。
由于
因此,權(quán)重的伸縮變化不會(huì)影響反向梯度的 Jacobian 矩陣,因此也就對反向傳播沒有影響,避免了反向傳播時(shí)因?yàn)闄?quán)重過大或過小導(dǎo)致的梯度消失或梯度爆炸問題,從而加速了神經(jīng)網(wǎng)絡(luò)的訓(xùn)練。
權(quán)重伸縮不變性還具有參數(shù)正則化的效果,可以使用更高的學(xué)習(xí)率。
由于
因此,下層的權(quán)重值越大,其梯度就越小。這樣,參數(shù)的變化就越穩(wěn)定,相當(dāng)于實(shí)現(xiàn)了參數(shù)正則化的效果,避免參數(shù)的大幅震蕩,提高網(wǎng)絡(luò)的泛化性能。
4.2 Normalization 的數(shù)據(jù)伸縮不變性
數(shù)據(jù)伸縮不變性僅對 BN、LN 和 CN 成立。因?yàn)檫@三者對輸入數(shù)據(jù)進(jìn)行規(guī)范化,因此當(dāng)數(shù)據(jù)進(jìn)行常量伸縮時(shí),其均值和方差都會(huì)相應(yīng)變化,分子分母互相抵消。而 WN 不具有這一性質(zhì)。
數(shù)據(jù)伸縮不變性可以有效地減少梯度彌散,簡化對學(xué)習(xí)率的選擇。
每一層神經(jīng)元的輸出依賴于底下各層的計(jì)算結(jié)果。如果沒有正則化,當(dāng)下層輸入發(fā)生伸縮變化時(shí),經(jīng)過層層傳遞,可能會(huì)導(dǎo)致數(shù)據(jù)發(fā)生劇烈的膨脹或者彌散,從而也導(dǎo)致了反向計(jì)算時(shí)的梯度爆炸或梯度彌散。
數(shù)據(jù)的伸縮變化也不會(huì)影響到對該層的權(quán)重參數(shù)更新,使得訓(xùn)練過程更加魯棒,簡化了對學(xué)習(xí)率的選擇。
參考文獻(xiàn)
[1] Sergey Ioffe and Christian Szegedy. Accelerating Deep Network Training by Reducing Internal Covariate Shift.
[2] Jimmy L. Ba, Jamie R. Kiros, Geoffrey E. Hinton. [1607.06450] Layer Normalization.
[3] Tim Salimans, Diederik P. Kingma. A Simple Reparameterization to Accelerate Training of Deep Neural Networks.
[4] Chunjie Luo, Jianfeng Zhan, Lei Wang, Qiang Yang. Using Cosine Similarity Instead of Dot Product in Neural Networks.
[5] Ian Goodfellow, Yoshua Bengio, Aaron Courville. Deep Learning.
本文在寫作過程中,參考了以下各位的回答,特此致謝。
@魏秀參的回答: 深度學(xué)習(xí)中 Batch Normalization為什么效果好?
@孔濤的回答: 深度學(xué)習(xí)中 Batch Normalization為什么效果好?
@王峰的回答: 深度學(xué)習(xí)中 Batch Normalization為什么效果好?
@lqfarmer的回答: Weight Normalization 相比batch Normalization 有什么優(yōu)點(diǎn)呢?
@Naiyan Wang的回答: Batch normalization和Instance normalization的對比?
@YJango的文章: YJango的Batch Normalization--介紹
-End-
推薦閱讀
2020年醫(yī)學(xué)圖像處理領(lǐng)域值得關(guān)注的期刊和會(huì)議
清華劉知遠(yuǎn)教授:好的研究想法從哪里來?
歡迎關(guān)注我的極術(shù)專欄:AI搬運(yùn)小能手,給您分享最前沿靠譜的高質(zhì)量AI技術(shù)干貨。
審核編輯 黃昊宇
-
機(jī)器學(xué)習(xí)
+關(guān)注
關(guān)注
66文章
8438瀏覽量
132938 -
深度學(xué)習(xí)
+關(guān)注
關(guān)注
73文章
5512瀏覽量
121421
發(fā)布評論請先 登錄
相關(guān)推薦
評論