解決GAN模式崩潰的兩條思路:改進優化和網絡架構

2019-10-16     有三AI

今天講述的內容仍然是GAN中的模式崩潰問題,首先將說明模式崩潰問題的本質,並介紹兩種解決模式崩潰問題的思路,然後將介紹一種簡單而有效的解決方案MAD-GAN,最後一部分將給出MAD-GAN的強化版本MAD-GAN-Sim。

作者 | 小米粥

編輯 | 言有三

1. 解決模式崩潰的兩條路線

GAN的模式崩潰問題,本質上還是GAN的訓練優化問題,理論上說,如果GAN可以收斂到最優的納什均衡點,那模式崩潰的問題便自然得到解決。舉例如下圖,紅線代表生成數據的機率密度函數,而藍線代表訓練數據集的機率密度函數,本來紅線只有一個模式,也就是生成器幾乎只會產生一種樣本,而在理論上的最優解中,紅線與藍線重合,這時候在生成器中採樣自然能幾乎得到三種樣本,與訓練集的數據表現為一致。

當然,實際中幾乎不會達到全局最優解,我們看似收斂的GAN其實只是進入了一個局部最優解。故一般而言,我們有兩條思路解決模式崩潰問題:

1.提升GAN的學習能力,進入更好的局部最優解,如下圖所示,通過訓練紅線慢慢向藍線的形狀、大小靠攏,比較好的局部最優自然會有更多的模式,直覺上可以一定程度減輕模式崩潰的問題。

例如上一期unrolled GAN,便是增加了生成器「先知」能力;

2.放棄尋找更優的解,只在GAN的基礎上,顯式地要求GAN捕捉更多的模式(如下圖所示),雖然紅線與藍線的相似度並不高,但是「強制」增添了生成樣本的多樣性,而這類方法大都直接修改GAN的結構。

2. MAD-GAN

今天要介紹的MAD-GAN及其變體便是第二類方法的代表之一。

它的核心思想是這樣的:即使單個生成器會產生模式崩潰的問題,但是如果同時構造多個生成器,且讓每個生成器產生不同的模式,則這樣的多生成器結合起來也可以保證產生的樣本具有多樣性,如下圖的3個生成器:

需要說明一下,簡單得添加幾個彼此孤立的生成器並無太大意義,它們可能會歸併成相同的狀態,對增添多樣性並無益處,例如下圖的3個生成器:

理想的狀態是:多個生成器彼此「聯繫」,不同的生成器儘量產生不相似的樣本,而且都能欺騙判別器。

在MAD(Multi-agent diverse)GAN中,共包括k個初始值不同的生成器和1個判別器,與標準GAN的生成器一樣,每個生成器的目的仍然是產生虛假樣本試圖欺騙判別器。對於判別器,它不僅需要分辨樣本來自於訓練數據集還是其中的某個生成器(這仍然與標準GAN的判別器一樣),而且還需要驅使各個生成器儘量產生不相似的樣本。

需要將判別器做一些修改:將判別器最後一層改為k+1維的softmax函數,對於任意輸入樣本x,D(x)為k+1維向量,其中前k維依次表示樣本x來自前k個生成器的機率,第k+1維表示樣本x來自訓練數據集的機率。同時,構造k+1維的delta函數作為標籤,如果x來自第i個生成器,則delta函數的第i維為1,其餘為0,若x來自訓練數據集,則delta函數的第k+1維為1,其餘為0。顯然,D的目標函數應為最小化D(x)與delta函數的交叉熵:

直觀上看,這樣的損失函數會迫使每個x儘量只產生於其中的某一個生成器,而不從其他的生成器中產生,將其展開則為:

生成器目標函數為:

對於固定的生成器,最優判別器為:

可以看出,其形式幾乎同標準形式的GAN相同,只是不同生成器之間彼此「排斥」產生不同的樣本。另外,可以證明當

達到最優解,再一次可以看出,MAD-GAN中並不需要每個生成器的生成樣本機率密度函數逼近訓練集的機率密度函數,每個生成器都分別負責生成不同的樣本,只須保證生成器的平均機率密度函數等於訓練集的機率密度函數即可。

3. MAD-GAN-Sim

MAD-GAN-Sim是一種「更強力」的版本,它不僅考慮了每個生成器都分別負責生成不同的樣本,而且更細緻地考慮了樣本的相似性問題。其出發點在於:來自於不同模式的樣本應該是看起來不同的,故不同的生成器應該生成看起來不相似的樣本。

這一想法用數學符號描述即為:

其中φ (x)表示從生成樣本的空間到特徵空間的某種映射(我們可選擇生成器的中間層,其思想類似於特徵值匹配),Δ (x,y)表示相似度的度量,多選用餘弦相似度函數,用於計算兩個樣本對應的特徵的相似度。

對於給定的噪聲輸入z,考慮第i個生成器與其他生成器的樣本生成情況,若樣本相似度比較大,則D(G_i(z))相比較D(G_j(z))應該大很多,由於D(G_j(z))的值比較小,G_j(z)便會進行調整不再生成之前的那個相似的樣本,轉而去生成其他樣本,利用這種「排斥」機制,我們就實現了讓不同的生成器應該生成看起來不相似的樣本。

將上述限制條件引入到生成器中,我們可以這樣訓練生成器,對於任意生成器i,對於給定的z,如果上面的條件滿足,則像MAD-GAN一樣正常計算,其梯度為:

如果條件不滿足,將上述條件作為正則項添加到目標函數中,則其梯度為:

這樣儘量使得判別器更新後,條件能夠滿足。MAD-GAN-Sim的思路非常直接清晰,不過代價就是增加非常多的計算量。

[1]Ghosh A , Kulharia V , Namboodiri V , et al. Multi-Agent Diverse Generative Adversarial Networks[J]. 2017.

總結

今天首先說明了模式崩潰問題的本質,並介紹兩種解決模式崩潰問題的思路,然後介紹一種簡單而有效的解決方案MAD-GAN及其強化版本MAD-GAN-Sim。

文章來源: https://twgreatdaily.com/zh/FaFP120BMH2_cNUg1T8Y.html