1 簡介
預訓練模型BERT以及相關的變體自從問世以后基本占據了各大語言評測任務榜單,不斷刷新記錄,但是,BERT龐大的參數量所帶來的空間跟時間開銷限制了其在下游任務的廣泛應用?;诖?,人們希望能通過Bert得到一個更小規模的模型,同時基本具備Bert的能力,從而為下游任務的大規模應用提供可能性。目前許多跟Bert相關的蒸餾方法被提出來,本章節就來分析下這若干蒸餾方法之間的細節以及差異。
知識蒸餾由兩個模型組成,teacher模型跟student模型,一般teacher模型規模跟參數量都比較龐大,所以能力更強,而student模型規模比較小,如果直接訓練的話效果比較有限,所以是先訓練teacher模型,讓它學到充足的知識,然后用student模型去學習teacher模型的行為,從而實現將知識從teacher模型轉移到student模型,使得student模型能在較小的參數量的同時具備接近大模型的能力。在蒸餾過程中,最常見的student模型部分的loss,就是對于同一個數據,將teacher模型的預測的soft概率作為ground truth,讓teacher模型去學習從而預測得到相同的結果,這部分teacher模型跟student模型預測的概率之間距離就是蒸餾最常見的loss(通常是交叉熵)。蒸餾學習希望student模型學到teacher模型的能力,從而預測的結果跟teacher模型預測的soft概率足夠接近,也就是希望這部分的loss盡可能的小。
2 DualTrain+SharedProj
以往的知識蒸餾雖然可以有效的壓縮模型尺寸,但很難將teacher模型的能力蒸餾到一個更小詞表的student模型中,而DualTrain+SharedProj解決了這個難題。它主要針對Bert的詞表大小跟嵌入緯度做了縮簡,其余部分,包括模型結構跟層數保持跟teacher模型(Bert Base)一致,從而實現將知識從teacher模型遷移到student模型中。
圖1: DualTrain+SharedProj框架
區別于其他蒸餾方法,DualTrain+SharedProj有兩個特別的地方,一個是Dual Training, 另一個是Shared Projection。Dual Training主要是為了解決teacher模型跟student模型不共用詞表的問題,在蒸餾過程中,對于teacher模型,會隨機選擇teacher模型或者student模型的詞表去進行分詞,可以理解就是混合了teacher模型跟student模型的詞表,這種方式可以對齊兩個規模不同的的詞表。例如圖中左邊部分,I和machine用的是teacher模型的分詞結果而其余token用的是student模型的分詞結果。第二部分是Shared Projection,這部分很好理解,因為student模型嵌入層緯度縮小了,導致每個transformer層的緯度都縮小了,但是我們希望student模型跟teacher模型的transformer層的參數足夠接近,所以這里需要一個可訓練的矩陣將兩個不同維度的transformer層參數縮放到同一個維度才能進行比較。如果是對teacher模型的參數進行縮放,就叫做down projection,如果是對student模型參數進行的縮放,就叫做up projection。同時,12層的transformer參數共用同一個縮放矩陣,所以叫做shared projection。例如下圖,下標t,s分別代表teacher模型跟student模型。
圖2: up projection損失
圖3: DualTrain+SharedProj的損失函數
在蒸餾過程中,會將teacher模型跟student模型都在監督數據上進行訓練,將兩個模型預測結果的損失加上兩個模型之間transformer層的參數之間的距離的損失作為最終損失,去更新student模型的參數。最終實驗效果可也表明,隨著student模型的隱藏層緯度縮減得越厲害,模型的效果也會逐漸變差。
圖4: DualTrain+SharedProj的實驗效果
DualTrain+SharedProj是很少見的student模型跟teacher模型不共享詞表的一種蒸餾方式,通過縮小詞表跟縮減嵌入層緯度,可以很大程度的減少模型的尺寸。同時也要注意,尺寸縮小得厲害,student模型的效果也下降地越厲害。另外有一點我不太理解,只通過一個dual training過程就可以對齊兩個詞表了嗎?是不是要蒸餾開始之前先對teacher模型,混合兩個詞表的分詞結果做下預訓練會更加合理?
3DistillBERT
DistilBERT是通過一種比較常規的蒸餾方法得到的,它的teacher模型依舊是Bert Base,DistilBERT沿用了Bert的結構,但是transfromer層數只有6層(Bert Base有12層),同時還將嵌入層token-type embedding跟最后的pooling層移除。為了讓DistilBERT有一個更加合理的初始化,DistilBERT的transformer參數來源于Bert Base,每隔兩層transformer取其中一層的參數來作為DistilBERT的參數初始化。
在蒸餾過程中,除了常規的蒸餾部分的loss,還加入了一個自監督訓練的loss(MLM任務的loss),除此之外,實驗還發現加入一個詞嵌入的loss有利于對齊teacher模型跟student模型的隱藏層表征。
DistilBERT是一種常見的通過蒸餾得到的方法,基本上是通過減少transformer的層數來減少模型尺寸,同時加速模型推理的。
4LSTM
蒸餾學習并不要求teacher模型跟student模型要隸屬于同一種模型架構,于是就有人腦洞大開,想用BiLSTM作為student模型來承載Bert Base龐大的能力。這里的teacher模型依舊是Bert Base,student模型分為三個部分,第一部分是詞嵌入層,第二部分是雙向LSTM+pooling,這里會將BiLSTM得到的隱藏層狀態通過max pooling生成句子的表征,第三部分是全連接層,直接輸出各個類別的概率。
在蒸餾開始之前,需要先在特定任務的監督數據集上對teacher模型進行微調,因為是分類任務,所以Bert Base跟后面的全連接層會一起更新參數,從而讓teacher模型適配下游任務。在蒸餾過程中,student模型的損失分為三部分,第一部分依舊是常規的根據teacher模型預測的soft概率跟student模型預測的概率之間的交叉熵損失。第二部分是在監督數據下student模型預測的結果跟真實標簽結果之間的交叉熵損失。第三部分是teacher模型跟student模型生成表征之間的KL距離,也就是BiLSTM+pooling跟Bert base最后一層狀態輸出之間的距離,但是由于這兩者可能維度不一樣,所以這里也需要引入一個全連接層來縮放。
圖5: BiLSTM的蒸餾過程
圖6: BiLSTM蒸餾的效果對比
可以看得到通過蒸餾得到的BiLSTM明顯優于直接finetune的,這里證明了蒸餾學習的有效性。除此之外,BiLSTM本身的準確率就很高了,說明任務比較簡單(要不然蒸餾過后的BiLSTM準確率比teacher模型Bert Base還高不是很詭異嘛?),所以并不能說明把Bert Base蒸餾到BiLSTM是個合適的選擇。LSTM本身結構的局限性導致了很難完全學習到transformer的知識跟能力,筆者以前也在一些比較難的數據集上嘗試過類似的做法,但是最終作為student模型的LSTM的效果跟teacher模型的之間的差距還是比較大,并且泛化能力比較差。
5 PDK
PKD想通過蒸餾學習將Bert Base的transformer層數進行壓縮,但是常規的方式只學習teacher模型最后一層的結果,雖然能在訓練集上取得可以媲美teacher模型的效果,但是在測試集的表現很快就收斂了。這種現象看起來像是在訓練集上過擬合了,從而影響了student模型的泛化能力?;诖?,PKD在原本的基礎上加上了新的約束項,驅使student模型去學習模仿teacher模型的中間過程。具體的有兩種可能方式,第一種就是讓student模型去學習teacher模型transformer每隔幾層的結果,第二種是讓student模型去學習teacher模型最后幾層transformer的結果。
蒸餾過程的損失函數包括三個部分,第一部分還是常規的teacher模型預測的soft概率和student模型預測結果之間的交叉熵損失,第二部分是student模型預測概率跟真實標簽之間的交叉熵損失,第三部分就是teacher模型跟student模型之間中間狀態的距離,這里用的[CLS]位置的表征。
6TinyBert
TinyBert的特別之處在于它的蒸餾過程分為兩個階段。第一階段是通用蒸餾,teacher model是預訓練好的Bert, 可以幫助TinyBert學習到豐富的知識,具備強大的通用能力,第二階段是特定任務蒸餾,teacher moder是經過finetune的Bert, 使得TinyBert學習到特定任務下的知識。兩個蒸餾環節的設計,能保證TinyBert強大的通用能力跟特定任務下的提升。
在每個蒸餾環節下,student模型的蒸餾分為三個部分,Embedding-layer Distillation,Transformer-layer Distillation, Prediction-layer Distillation。Embedding-layer Distillation是詞嵌入層的蒸餾,使得TinyBert更小維度的embedding輸出結果盡可能的接近Bert的embedding輸出結果。Transformer-layer Distillation是其中transformer層的蒸餾,這里的蒸餾采用的是隔k層蒸餾的方式。也就是,假如teacher model的Bert的transformer有12層,如果TinyBert的transformer設計有4層,那么就是就是每隔3層蒸餾,TinyBert的第1,2,3,4層transformer分別學習的是Bert的第3,6,9,12層transformer層的輸出。Prediction-layer Distillation主要是對齊TinyBert跟Bert在預測層的輸出,這里學習的是預測層的logit,也就是概率值。前面兩部分的損失都是MSE計算,因為teacher模型跟student模型在嵌入層跟隱藏層的維度不一致,所以這里需要相應的線性映射將student模型的中間輸出映射到跟teacher 模型一樣的維度,最后一部分的損失是通過交叉熵損失計算的。通過這三部分的學習,能保證TinyBert在中間層跟最后預測層都學習到Bert相應的結果,進而保證準確率。
圖7: TinyBert框架
TinyBert的兩階段蒸餾過程能驅使student模型能學到teacher模型的通用知識和特定領域知識,保證student模型在下游任務的表現,是很值得借鑒的一種訓練技巧。
7 MOBILEBERT
MOBILEBERT可能是目前性價比最高的一種蒸餾方式了(可能是筆者眼界有限),無論是從學習的目標,還是整個訓練的方式,考慮都很周全。MOBILEBERT的student模型跟teacher模型的網絡層數保持一致,相關的模型結構有所變化,首先是student模型跟teacher模型都新增了bottleneck,用于縮放內部表示尺寸,在后面loss部分會展開介紹,其次是student模型里將FFN改成堆疊的FFN,最后是移除了layer normalization跟將激活函數由gelu換成relu.
在蒸餾過程中,student模型的損失包括兩個部分。第一個部分是student模型和teacher模型之間的feature map的距離,這里的feature map指的是每一層transformer輸出的結果。在這里,為了能讓student模型的隱藏層維度比teacher模型的隱藏層維度更小從而實現模型壓縮,這里的student模型跟teacher模型的transformer結構都加入了bottleneck,也就是圖中綠色梯形的部分,通過這些bottleneck可以對文本表征尺寸進行縮放,從而實現teacher模型跟student模型各自在每一個transformer內部表示尺寸不同,但是輸入和輸出尺寸一致,所以就可能用內部表示尺寸小的student模型去學習內部表示尺寸大的teacher模型的能力跟知識。第二部分是兩個模型每一層transformer中attention的距離,這部分loss是為了利用self attention從teacher模型中學習到相關內容從而更好得學習到第一部分的feature map。
圖8: MOBILEBERT相關的網絡結構
MOBILEBERT的蒸餾過程是漸近式的,在蒸餾學習第L層的參數時會固定L層以下的參數,一層一層的學習teacher模型的,直到學完全部層數。
圖9: MOBILEBERT的漸近式知識遷移過程
在完成蒸餾學習后,MOBILEBERT還會在做進一步的預訓練,預訓練有三部分的loss,第一部分跟第二部分是BERT預訓練的MLM跟NSP任務的loss,第三部分是teacher模型跟student模型在[MASK]位置的預測概率之間的交叉熵損失。
8總結
為了直觀的對比上面提及的蒸餾方法的壓縮效率和模型效果,我們匯總了若干種模型的具體信息以及在MRPC數據集上的表現??傮w來說,有以下一些相關結論。
a)壓縮效率越高往往會伴隨著模型效果的持續下降。
b)Student模型的上限就是teacher模型。對于同一個student模型,并不是teacher模型越大student模型效果就會越好。因為越大的teacher模型,意味著更大的壓縮效率,也意味著更嚴重的性能下降。
c)只學習teacher模型最后的預測的soft概率是遠遠不夠的,需要對teacher模型中間的表征或者參數也進行學習,才能進一步保證student模型的效果。
d)縮減transformer層數或者縮減隱藏層狀態緯度都可以壓縮模型,對于縮減隱藏層狀態維度,用MOBILEBERT那種bottleneck的方式優于常規的通過一個額外的映射來對齊模型尺寸的方式??s減隱藏層狀態維度的方式的模型壓縮效率的上限更高。
e)漸進性學習方式是有效的。也就是固定下層的參數,只更新當前層的參數,依次迭代直至更新完student模型全部層。
f)分階段蒸餾是有效的。先學習通用的teacher模型,然后再學習特定任務下finetune的teacher模型。
g)跨模型結構的蒸餾是有效的。用BiLSTM來學習Bert Base的能力比直接finetune BiLSTM的效果要好。
Model | type | Compress Factor | MRPC(f1) |
Bert Base | 1 | 88.9 | |
DualTrain+SharedProjUp |
192 96 48 |
5.74 19.41 61.94 |
84.9 84.9 79.3 |
DistilBERT | 1.67 | 87.5 | |
PKD |
6 3 |
1.64 2.40 |
85.0 80.7 |
TinyBert | 4 | 7.50 | 86.4 |
MOBILEBERT | 4.30 | 88.8 |
參考文獻
1.(2020) EXTREME LANGUAGE MODEL COMPRESSION WITH OPTIMAL SUBWORDS AND SHARED PROJECTIONS
https://openreview.net/pdf?id=S1x6ueSKPr
2. (2020) DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter
https://arxiv.org/abs/1910.01108
3. (2020) DISTILLING BERT INTO SIMPLE NEURAL NETWORKS WITH UNLABELED TRANSFER DATA
https://arxiv.org/pdf/1910.01769.pdf
4. (2019) Patient Knowledge Distillation for BERT Model Compression
https://arxiv.org/pdf/1908.09355.pdf
5.(2020)TINYBERT: DISTILLING BERT FOR NATURAL LAN- GUAGE UNDERSTANDING
https://openreview.net/attachment?id=rJx0Q6EFPB&name=original_pdf
6. (2020) MOBILEBERT: TASK-AGNOSTIC COMPRESSION OF BERT BY PROGRESSIVE KNOWLEDGE TRANSFER
https://openreview.net/pdf?id=SJxjVaNKwB
審核編輯 :李倩
-
模型
+關注
關注
1文章
3243瀏覽量
48840 -
LSTM
+關注
關注
0文章
59瀏覽量
3752
原文標題:Bert系列之知識蒸餾
文章出處:【微信號:zenRRan,微信公眾號:深度學習自然語言處理】歡迎添加關注!文章轉載請注明出處。
發布評論請先 登錄
相關推薦
評論