有人認為,用低精度訓練機器學習模型會限制訓練的精度,事實真的如此嗎?本文中,斯坦福大學的DAWN人工智能研究院介紹了一種名為bit recentering的技術,它可以用低精度的計算實現高準確度的解決方案。以下是論智對原文的編譯,文末附原論文地址。
低精度計算在機器學習中已經吸引了大量關注。一些公司甚至已經開始研發能夠原生支持并加速低精度操作的硬件了,比如微軟的腦波計劃(Project Brainwave)和谷歌的TPU。雖然使用低精度計算對系統來說有很多好處,但是低精度方法仍然主要用于推理,而非訓練。此前,低精度訓練算法面臨著一個基本困境(fundamental tradeoff):當使用較少的位進行計算時,舍棄誤差就會增加,這就限制了訓練的準確度。根據傳統觀點,這種制約限制了研究人員在系統中部署低精度訓練算法的能力,但是這種限制能否改變?是否有可能設計一種使用低精度卻不會限制準確度的算法呢?
答案是肯定的。在某些情況下我們可以從低精度訓練中獲得高準確度的解決方案,在這里我們使用了一種新的隨機梯度下降方法,稱為高準確度低精度(HALP)法。HALP比之前的算法表現更好,因為它減少了兩個限制低精度隨機梯度下降準確度的噪聲源:梯度方差和舍棄誤差。
為了減少梯度方差帶來的噪音,HALP使用常見的SVRG(stochastic variance-reduced gradient)技術。SVRG能定期使用完全梯度來減少隨機梯度下降中使用的梯度樣本的方差。
為了降低量化數字帶來的噪聲,HALP使用了名為“bit centering”的新技術,它背后的原理是,當我們接近最優點時,梯度漸變的幅度變小。也就是說攜帶的信息變少,于是我們能對其進行壓縮。通過動態地重新調整低精度數字,我們可以在算法收斂時降低量化噪聲。
將這兩種技術結合,HALP能夠以和全精度SVRG同樣的線性收斂率生成任意準確地解決方案,同時在低精度迭代時使用的是固定位數。這個結果顛覆了有關低精度訓練算法的傳統觀點。
為什么低精度的隨機梯度下降有所限制?
首先先交代一下背景:我們想要解決以下這個訓練問題:
這是用來訓練許多機器學習模型(包括深度神經網絡)的經典實證問題:讓風險最小化。解決這個問題的標準方法之一是隨機梯度下降,它是一種通過運行接近最佳值的迭代算法。
在每次迭代時,it是從{1,..., N}中隨機挑選的一個指數,我們雖然想運行這樣的算法,但是要保證迭代wt是低精度的。也就是說,我們希望它們使用較少位的定點運算(通常為8位或16位)。但是,當直接對隨機梯度下降更新規則而進行這項操作時,我們遇到了問題:問題的解決方案w可能無法再選中的定點表示中顯示出來。例如,如果一個8位的定點表示,可以儲存{-128,-127,…,127}之間的整數,正確的解決方法是w*=100.5,那么我們與解決方案的距離不可能小于0.5,因為我們不能表示非整數。除此之外,將梯度轉換為定點導致的舍棄誤差可能會減慢收斂速度,這也影響了低精度SGD的準確性。
Bit Centering
當我們運行隨機梯度下降時,在某種意義上,我們世紀正對一堆梯度樣本進行平均(或總結)。Bit Centering背后的關鍵思想是隨著梯度漸變逐漸變小,我們可以用同樣的位數、以較小的誤差對它們求平均值。想要知道為什么,想像一下,你想對[-100, 100]之間的數字求平均值,然后和[-1, 1]的平均值作比較。在前一個集合中,我們需要選擇一個定點表示,它可以覆蓋整個集合(例如{-128,-127,…,127})。而在第二個集合中,我們選擇的定點要覆蓋[-1, 1],例如{-128/127,-127/127,..., 126/127,127/127}。這就意味著在固定位數情況下,第二種情況中的相鄰可表示數字之間的差值比第一種情況更小,因此舍棄誤差也更低。
這個關鍵的想法讓我們得到了啟發。為了在[-1, 1]中求出比[-100, 100]中更少誤差的平均數,我們需要用一個不同的定點表示,即我們應該不斷更新低精度表示:隨著梯度漸變得越小,我們應該用位數更小的定點數字,覆蓋更小的范圍。
但是我們該如何更新表示呢?我們要覆蓋的范圍到底多大?如果目標是帶有參數μ的強凸,那么不管我們何時在某一點w采取完整的梯度漸變是,都可以用以下公式限制最佳位置
這種不等式為最終的解決方案提供了一系列可能的定位,所以無論何時計算完整梯度,我們都可以重新居中并縮放低精度表示以覆蓋此范圍。下圖說明了這一過程。
HALP
HALP是運行SVRG并在每個時期都使用具有完全梯度的bit centering更新低精度表示的算法。原論文有對這一方法的具體描述,在這里我們只簡單做些介紹。首先,我們證明了,對于強凸的Lipschitz光滑函數,只要我們使用的位數b滿足
其中κ是該問題的條件數字,那么為了適當設置尺寸和時間長度,HALP將以線性速度收斂到任意準確度的解。更顯然的是,0<γ<1,
其中wk+1表示第K次迭代后的值。下表表現了這一變化過程
圖表通過對具有100個特征和1000個樣本的合成數據集進行線性回歸,來評估HALP。將它與全精度梯度下降、SVRG、低精度的梯度下降和低精度的SVRG進行比較。需要注意的是,盡管只有8位(受到浮點錯誤的限制),HALP仍能收斂到精度非常高的結果上。在這種情況下,HALP可以比全精度SVRG收斂到更高精度的結果中,因為HALP較少使用浮點運算,因此對浮點的非準確性不敏感。
-
機器學習
+關注
關注
66文章
8438瀏覽量
132900
原文標題:斯坦福DAWN實驗室實現用低精度計算產生高準確度結果
文章出處:【微信號:jqr_AI,微信公眾號:論智】歡迎添加關注!文章轉載請注明出處。
發布評論請先 登錄
相關推薦
評論