微軟研究人員在ICLR 2018發表了一種新的GAN(對抗網絡生成)訓練方法,boundary-seeking GAN(BGAN),可基于離散值訓練GAN,并提高了GAN訓練的穩定性。
對抗生成網絡
首先,讓我們溫習一下GAN(對抗生成網絡)的概念。簡單來說,GAN是要生成“以假亂真”的樣本。這個“以假亂真”,用形式化的語言來說,就是假定我們有一個模型G(生成網絡),該模型的參數為θ,我們要找到最優的參數θ,使得模型G生成的樣本的概率分布Qθ與真實數據的概率分布P盡可能接近。即:
其中,D(P, Qθ)為P與Qθ差異的測度。
GAN的主要思路,是通過引入另一個模型D(判別網絡),該模型的參數為φ,然后定義一個價值函數(value function),找到最優的參數φ,最大化這一價值函數。比如,最初的GAN(由Goodfellow等人在2014年提出),定義的價值函數為:
其中,Dφ為一個使用sigmoid激活輸出的神經網絡,也就是一個二元分類器。價值函數的第一項對應真實樣本,第二項對應生成樣本。根據公式,D將越多的真實樣本歸類為真(1),同時將越多的生成樣本歸類為假(0),D的價值函數的值就越高。
GAN的精髓就在于讓生成網絡G和判別網絡D彼此對抗,在對抗中提升各自的水平。形式化地說,GAN求解以下優化問題:
如果你熟悉Jensen-Shannon散度的話,你也許已經發現了,之前提到的最初的GAN的價值函數就是一個經過拉伸和平移的Jensen-Shannon散度:2 * DJSD(P||Qθ) - log 4. 除了這一Jensen-Shannon散度的變形外,我們還可以使用其他測度衡量分布間的距離,Nowozin等人在2016年提出的f-GAN,就將GAN的概念推廣至所有f-散度,例如:
Jensen-Shannon
Kullback–Leibler
Pearson χ2
平方Hellinger
當然,實際訓練GAN時,由于直接計算這些f-散度比較困難,往往采用近似的方法計算。
GAN的缺陷
GAN有兩大著名的缺陷:難以處理離散數據,難以訓練。
GAN難以處理離散數據
為了基于反向傳播和隨機梯度下降之類的方法訓練網絡,GAN要求價值函數在生成網絡的參數θ上完全可微。這使得GAN難以生成離散數據。
假設一下,我們給生成網絡加上一個階躍函數(step function),使其輸出離散值。這個階躍函數的梯度幾乎處處為0,這就使GAN無法訓練了。
GAN難以訓練
從直覺上說,訓練判別網絡比訓練生成網絡要容易得多,因為識別真假樣本通常比偽造真實樣本容易。所以,一旦判別網絡訓練過頭了,能力過強,生成網絡再怎么努力,也無法提高,換句話說,梯度消失了。
另一方面,如果判別網絡能力太差,胡亂分辨真假,甚至把真的誤認為假的,假的誤認為真的,那生成網絡就會很不穩定,會努力學習讓生成的樣本更假——因為弱智的判別網絡會把某些假樣本當成真樣本,卻把另一些真樣本當成假樣本。
還有一個問題,如果生成網絡湊巧在生成某類真樣本上特別得心應手,或者,判別網絡對某類樣本的辨別能力相對較差,那么生成網絡會揚長避短,盡量多生成這類樣本,以增大騙過判別網絡的概率,這就導致了生成樣本的多樣性不足。
所以,判別網絡需要訓練得恰到好處才可以,這個火候非常難以控制。
強化學習和BGAN
那么,該如何避免GAN的缺陷呢?
我們先考慮離散值的情況。之所以GAN不支持生成離散值,是因為生成離散值導致價值函數(也就是GAN優化的目標)不再處處可微了。那么,如果我們能對GAN的目標做一些手腳,使得它既處處可微,又能衡量離散生成值的質量,是不是可以讓GAN支持離散值呢?
關鍵在于,我們應該做什么樣的改動?關于這個問題,可以從強化學習中得到靈感。實際上,GAN和強化學習很像,生成網絡類似強化學習中的智能體,而騙過判別網絡類似強化學習中的獎勵,價值函數則是強化學習中也有的概念。而強化學習除了可以根據價值函數進行外,還可以根據策略梯度(policy gradient)進行。根據價值函數進行學習時,基于價值函數的值調整策略,迭代計算價值函數,價值函數最優,意味著當前策略是最優的。而根據策略梯度進行時,直接學習策略,通過迭代計算策略梯度,調整策略,取得最大期望回報。
咦?這個策略梯度看起來很不錯呀。引入策略梯度解決了離散值導致價值函數不是處處可微的問題。更妙的是,在強化學習中,基于策略梯度學習,有時能取得比基于值函數學習更穩定、更好的效果。類似地,引入策略梯度后GAN不再直接根據是否騙過判別網絡調整生成網絡,而是間接基于判別網絡的評價計算目標,可以提高訓練的穩定度。
BGAN(boundary-seeking GAN)的思路正是如此。
BGAN論文的作者首先證明了目標密度函數p(x)等于(?f/?T)(T(x))qθ(x)。其中,f為生成f-散度的函數,f*為f的凸共軛函數。
令w(x) = (?f/?T)(T*(x)),則上式可以改寫為:
p(x) = (w*(x))qθ(x)
這樣改寫后,很明顯了,這可以看成一個重要性采樣(importance sampling)問題。(重要性采樣是強化學習中推導策略梯度的常用方法。)相應地,w*(x)為最優重要性權重(importance weight)。
類似地,令w(x) = (?f*/?T)(T(x)),我們可以得到f-散度的重要性權重估計:
其中,β為分區函數:
使用重要性權重作為獎勵信號,可以得到基于KL散度的策略梯度:
然而,由于這一策略梯度需要估計分區函數β(比如,使用蒙特卡洛法),因此,方差通常會比較大。因此,論文作者基于歸一化的重要性權重降低了方差。
令
其中,gθ(x | z): Z -> [0, 1]d為條件密度,h(z)為z的先驗。
令分區函數
則歸一化的條件權重可定義為
由此,可以得到期望條件KL散度:
令x(m)~ gθ(x | z)為取自先驗的樣本,又令
為使用蒙特卡洛估計的歸一化重要性權重,則期望條件KL散度的梯度為:
如此,論文作者成功降低了梯度的方差。
此外,如果考慮逆KL散度的梯度,則我們有:
上式中,靜態網絡的輸出Fφ(x)可以視為獎勵(reward),b可以視為基線(baseline)。因此,論文作者將其稱為基于強化的BGAN。
試驗
離散
為了驗證BGAN在離散設定下的表現,論文作者首先試驗了在CIFAR-10上訓練一個分類器。結果表明,搭配不同f-散度的基于重要性取樣、強化的BGAN均取得了接近基線(交叉熵)的表現,大大超越了WGAN(權重裁剪)的表現。
在MNIST上的試驗表明,BGAN可以生成穩定、逼真的手寫數字:
在MNIST上與WGAN-GP(梯度懲罰)的比較顯示,采用多種距離衡量,包括Wasserstein距離,BGAN都取得了更優的表現:
在quantized版本的CelebA數據集上的表現:
左為降采樣至32x32的原圖,右為BGAN生成的圖片
下為隨機選取的在1-billion word數據集上訓練的BGAN上生成的文本的3個樣本:
雖然這個效果還比不上當前最先進的基于RNN的模型,但此前尚無基于GAN訓練離散值的模型能實現如此效果。
連續
論文作者試驗了BGAN在CelebA、ImageNet、LSUN數據集上的表現,均能生成逼真的圖像:
在CIFAR-10與原始GAN、使用代理損失(proxy loss)的DCGAN的比較表明,BGAN的表現和訓練穩定性都是最優的:
-
網絡
+關注
關注
14文章
7553瀏覽量
88729 -
GaN
+關注
關注
19文章
1933瀏覽量
73286 -
函數
+關注
關注
3文章
4327瀏覽量
62571
原文標題:BGAN:支持離散值、提升訓練穩定性的新GAN訓練方法
文章出處:【微信號:jqr_AI,微信公眾號:論智】歡迎添加關注!文章轉載請注明出處。
發布評論請先 登錄
相關推薦
評論