在這一系列文章中,我們主要研究如何用單獨的GPU,在CIFAR-10圖像分類數據集上高效地訓練殘差網絡(Residual networks)。
為了記錄這一過程,我們計算了網絡從零開始訓練到94%的精確度所需的時間。這一基準來自最近的DAWNBench競賽。在競賽結束后,單個GPU上的最好成績是341秒,八個GPU上最好成績是174秒。
Baseline
在這部分中,我們復制了一個基線,在6分鐘內訓練CIFAR10,之后稍稍加速。我們發現,在GPU的FLOPs計算完之前,仍有很大的提升空間。
過去幾個月,我一直在研究如何能更快度訓練深度神經網絡。這個想法是從今年年初萌生的,當時我正和Myrtle的Sam Davis進行一個項目。我們將用于自動語音識別的大型循環神經網絡壓縮后,部署到FPGAs上,重新訓練模型。來自Mozilla的基線在16個GPU上訓練了一個星期。后來,經過Sam的努力,我們在英偉達的Volta GPUs上進行混淆精度訓練,得以將訓練時間縮短了100倍,迭代時間在單個GPU上只需要不到一天的時間。
這一結果讓我思考還有什么可以實現加速?幾乎與此同時,斯坦福大學的研究人員們開啟了DAWNBench挑戰賽,比較多個深度學習基線上的訓練速度。最受人關注的就是訓練圖像分類模型在CIFAR10上達到94%的測試精確度,在ImageNet上達到93%、top5的成績。圖像分類是深度學習研究的熱門領域,但是訓練速度仍需要數小時。
到了四月份,挑戰賽接近尾聲,CIFAR10上最快的單個GPU訓練速度來自fast.ai的一名學生Ben Johnson,他在不到6分鐘(341秒)的時間里訓練出了94%的精確度。這一創新主要是混淆精度的訓練,他選擇了一個較小的網絡,有足夠的能力處理任務并且可以用更高的學習速率加速隨機梯度下降。
這時我們不禁提出一個問題:這種341秒訓練出來的94%測試精度,在CIFAR10上的表現怎么樣?該網絡的架構是一個18層的殘差網絡,如下所示。在這個案例中,圖層的數量表示卷積(紫色)和完全連接層(藍色)的序列深度:
網絡通過隨機梯度下降訓練了35個epoch,學習速率圖如下:
現在我們假設在一個英偉達Volta V100 GPU上用100%的計算力,訓練將需要多長時間。網絡在一張32×32×3的CIFAR10圖像上進行前向和后向傳遞時需要大約2.8×109FLOPs。假設參數更新不耗費計算力,那么在50000張圖像訓練35個epoch應該會在5×1015FLOPs以內完成。
Tesla V100有640個Tensor Cores,能支持125 TeraFLOPS的深度學習性能。
假設我們能發揮100%的計算力,那么訓練會在40秒內完成,這么看來341秒的成績還有很大的提升空間。
有了40秒這個目標,我們就開始了自己的訓練。首先是用上方的殘差網絡重新復現基線CIFAR10的結果。我用PyTorch創建了一個網絡,重新復制了學習速率和超參數。在AWS p3.2的圖像上用單個V100 GPU訓練,3/5的運行結果在356秒內達到了94%的精確度。
基線建好后,下一步是尋找可以立即使用的簡單改進方法。首先我們觀察到:網絡開頭是由黃色和紅色的兩個連續norm-ReLU組成的,在紫色卷積之后,我們刪去重復部分,同樣在epoch 15也發生了這樣的情況。進行調整后,網絡架構變得更簡單,4/5的運行結果在323秒內達到了94%的精確度!刷新了記錄!
另外我們還觀察到,圖像處理過程中的一些步驟(填充、標準化、位移等等)每經過訓練集一次就要重新處理一遍,會浪費很多時間。雖然提前預處理可以用多個CPU處理器減輕這一結果,但是PyTorch的數據下載器會從每次數據迭代中開始新一次的處理。這一配置時間是很短的,尤其在CIFAR10這樣的小數據集上。只要在訓練前做了準備,減少預處理壓力,就能減少處理次數。遇到更復雜的任務,需要更多預處理步驟或多個GPU時,就會在每個epoch之間保持數據下載器的處理。溢出了重復工作、減少了數據下載器后,訓練時間達到了308秒。
繼續研究后我們發現,大部分預處理時間都花在了召集隨機數字生成器,選擇數據增強而不是為它們本身增強。在完全訓練時期,我們對隨機數字生成器執行了幾百萬個單獨命令,把它們結合在一個較小的命令中,每個epoch可以省去7秒訓練時間。最終的訓練時間縮短到了297秒。這一過程的代碼可以點擊:github.com/davidcpage/cifar10-fast/blob/master/experiments.ipynb
-
gpu
+關注
關注
28文章
4729瀏覽量
128901 -
圖像分類
+關注
關注
0文章
90瀏覽量
11914 -
深度學習
+關注
關注
73文章
5500瀏覽量
121117
原文標題:如何訓練你的ResNet(一):復現baseline,將訓練時間從6分鐘縮短至297秒
文章出處:【微信號:jqr_AI,微信公眾號:論智】歡迎添加關注!文章轉載請注明出處。
發布評論請先 登錄
相關推薦
評論