ICLR 2020 | "同步平均教學"框架為無監督學習提供更魯棒的偽標籤

2020-03-30     AI科技評論

本文介紹一篇由港中文發表於ICLR-2020的論文《Mutual Mean-Teaching: Pseudo Label Refinery for Unsupervised Domain Adaptation on Person Re-identification》[1],其旨在解決更實際的開放集無監督領域自適應問題,所謂開放集指預先無法獲知目標域所含的類別。這項工作在多個行人重識別任務上驗證其有效性,精度顯著地超過最先進技術13%-18%,大幅度逼近有監督學習性能。這也是ICLR收錄的第一篇行人重識別任務相關的論文,代碼和模型均已公開。


文 | 葛藝瀟

編 | 賈偉




論文連結:https://openreview.net/forum?id=rJlnOhVYPS

代碼連結:https://github.com/yxgeee/MMT

1、背景介紹


1.1、任務

行人重識別(Person ReID)旨在跨相機下檢索出特定行人的圖像,被廣泛應用於監控場景。如今許多帶有人工標註的大規模數據集推動了這項任務的快速發展,也為這項任務帶來了精度上質的提升。

然而,在實際應用中,即使是用大規模數據集訓練好的模型,若直接部署於一個新的監控系統,顯著的領域差異通常會導致明顯的精度下降。在每個監控系統上都重新進行數據採集和人工標註由於太過費時費力,也很難實現。

所以無監督領域自適應(Unsupervised Domain Adaptation)的任務被提出以解決上述問題,讓在有標註的源域(Source Domain)上訓練好的模型適應於無標註的目標域(Target Domain),以獲得在目標域上檢索精度的提升。值得注意的是,有別於一般的無監督領域自適應問題(目標域與源域共享類別),行人重識別的任務中目標域的類別數無法預知,且通常與源域沒有重複,這裡稱之為開放集(Open-set)的無監督領域自適應任務,該任務更為實際,也更具挑戰性。


1.2、動機

無監督領域自適應在行人重識別上的現有技術方案主要分為基於聚類的偽標籤法、領域轉換法、基於圖像或特徵相似度的偽標籤法,其中基於聚類的偽標籤法被證實較為有效,且保持目前最先進的精度 [2,3],所以該論文主要圍繞該類方法進行展開。基於聚類的偽標籤法,顧名思義,

(i)首先用聚類算法(K-Means, DBSCAN等)對無標籤的目標域圖像特徵進行聚類,從而生成偽標籤,

(ii)再用該偽標籤監督網絡在目標域上的學習。以上兩步循環直至收斂,如下圖所示:



儘管該類方法可以一定程度上隨著模型的優化改善偽標籤質量,但是模型的訓練往往被無法避免的偽標籤噪聲所干擾,並且在初始偽標籤噪聲較大的情況下,模型有較大的崩潰風險。所謂偽標籤噪聲主要來自於源域預訓練的網絡在目標域上有限的表現力、未知的目標域類別數、聚類算法本身的局限性等等。所以如何處理偽標籤噪聲對網絡最終的性能產生了至關重要的影響,但現有方案並沒有有效地解決它。


2、解決方法


2.1、概述

為了有效地解決基於聚類的算法中的偽標籤噪聲的問題,該文提出利用"同步平均教學"框架進行偽標籤優化,核心思想是利用更為魯棒的"軟"標籤對偽標籤進行在線優化。在這裡,"硬"標籤指代置信度為100%的標籤,如常用的one-hot標籤[0,1,0,0],而"軟"標籤指代置信度<100%的標籤,如[0.1,0.6,0.2,0.1]。



如上圖所示,A1與A2為同一類,外貌相似的B實際為另一類,由於姿態多樣性,聚類算法產生的偽標籤錯誤地將A1與B分為一類,而將A1與A2分為不同類,使用錯誤的偽標籤進行訓練會造成誤差的不斷放大。該文指出,網絡由於具備學習和捕獲數據分布的能力,所以網絡的輸出本身就可以作為一種有效的監督。然而,利用網絡的輸出來訓練自己是不可取的,會無法避免地造成誤差的放大。所以該文提出同步訓練對稱的網絡,在協同訓練下達到相互監督的效果,從而避免對網絡自身的輸出誤差形成過擬合。在實際操作中,該文利用"平均模型"進行監督,提供更為可信和穩定的"軟"標籤,將在下文進行描述。總的來說,該文

  • 提出"同步平均教學"(Mutual Mean-Teaching)框架為無監督領域自適應的任務提供更為可信的、魯棒的偽標籤;
  • 針對三元組(Triplet)設計合理的偽標籤以及匹配的損失函數,以支持協同訓練的框架。

2.2、同步平均教學



如上圖所示,該文提出的"同步平均教學"框架利用離線優化的"硬"偽標籤與在線優化的"軟"偽標籤進行聯合訓練。"硬"偽標籤由聚類生成,在每個訓練epoch前進行單獨更新;"軟"偽標籤由協同訓練的網絡生成,隨著網絡的更新被在線優化。直觀地來說,該框架利用同行網絡(Peer Networks)的輸出來減輕偽標籤中的噪聲,並利用該輸出的互補性來優化彼此。而為了增強該互補性,主要採取以下措施:

  • 對兩個網絡Net 1和Net 2使用不同的初始化參數;
  • 隨機產生不同干擾,例如,對輸入兩個網絡的圖像採用不同的隨機增強方式,如隨機裁剪、隨機翻轉、隨機擦除等,對兩個網絡的輸出特徵採用隨機dropout;
  • 訓練Net 1和Net 2時採用不同的"軟"監督,i.e. "軟"標籤來自對方網絡的"平均模型";
  • 採用網絡的"平均模型"Mean-Net 1/2而不是當前的網絡本身Net 1/2進行相互監督。

此處,"平均模型"的參數是對應網絡參數的累計平均值,具體來說,"平均模型"的參數不是由損失函數的反向傳播來進行更新的,而是在每次損失函數的反向傳播後,利用以下公式將對應的網絡參數以進行加權平均:

這裡,指第個iteration,和分別為Net 1和Net 2的當前參數。在初始化時,,。故"平均模型"可以看作對網絡過去的參數進行平均,兩個"平均模型"由於具有時間上的累積,解耦性更強,輸出更加獨立和互補。有一種簡單的協同學習方案是將此處的"平均模型"去除,直接使用網絡自己的輸出去監督對稱的網絡,如利用Net 1的輸出去監督Net 2。而在這樣的方案下存在兩點弊端,(1)由於網絡本身靠反向傳播參數更新較快,受噪聲影響更嚴重,所以用這樣不穩定的監督容易對網絡的學習造成影響,文章4.4的消融學習中進行了比較,(2)該簡化方案讓網絡直接訓練逼近彼此,會使得網絡迅速收斂至相似,降低輸出的互補性,文章附錄A.1中進行了詳細說明。值得注意的是,由於"平均模型"不會進行反向傳播,所以不需要計算和存儲梯度,並不會大規模增加顯存和計算複雜度。在測試時,只使用其中一個網絡進行推理,相比較baseline,不會增加測試時的計算複雜度。

在行人重識別任務中,通常使用分類損失與三元損失進行聯合訓練以達到較好的精度。其中分類損失作用於分類器的預測值,而三元損失直接作用於圖像特徵。為了方便展示,下文中,我們使用指代編碼器,指代分類器,每個Net都由一個編碼器和一個分類器組成,我們用角標,來區分Net 1和Net 2。我們使用角標,來區分源域和目標域,源域圖像及其標籤被表示為 ,目標域無標註的圖像表示為。

2.3、"軟"分類損失

利用"硬"偽標籤進行監督時,分類損失可以用一般的多分類交叉熵損失函數來表示:

上式中,為目標域圖像的"硬"偽標籤,由聚類產生。在"同步平均教學"框架中,"軟"分類損失中的"軟"偽標籤是"平均模型"Mean-Net 1/2的分類預測值。針對分類預測,很容易想到利用"軟"交叉熵損失函數來進行監督,該損失函數被廣泛應用於模型蒸餾,用以減小兩個分布間的距離:


上式中和表示同一張圖像經過不同的隨機數據增強方式。該式旨在讓Net 1的分類預測值逼近Mean-Net 2的分類預測值,讓Net 2的分類預測值逼近Mean-Net 1的分類預測值。


2.4、"軟"三元損失

傳統的三元(anchor, positive, negative)損失函數表示為:

上式中表示歐氏距離,下角標和分別表示的正樣本和負樣本,是餘量超參。這裡,正負樣本由聚類產生的偽標籤判斷,所以該式可以用以支持"硬"偽標籤的訓練。但是,不足以支持軟標籤的訓練,減法形式的三元損失也無法直觀地提供軟標籤。這裡的難點在於,如何在三元組的圖像特徵基礎上設計合理的"軟"偽標籤,以及如何設計對應的"軟"三元損失函數。該文提出使用softmax-triplet來表示三元組內特徵間的關係,表示為:

這裡softmax-triplet的取值範圍為[,可以用來替換傳統的三元損失,當使用"硬"偽標籤進行監督時,可以看作二分類問題,使用二元交叉熵損失函數進行訓練:

這裡的""指的是每個樣本與其負樣本的歐氏距離應該遠遠大於與正樣本的歐氏距離。但由於偽標籤存在噪聲,並不能完全正確地區分正負樣本,所以該文提出需要軟化對三元組的監督(使用"平均模型"輸出的特徵距離比代替硬標籤"1",軟化後標籤取值範圍在 之間)。具體來說,在"同步平均教學"框架中,"平均模型"編碼的圖像特徵計算出的softmax-triplet可用作"軟"偽標籤以監督三元組的訓練:

該損失函數旨在讓Net 1輸出的softmax-triplet逼近Mean-Net 2的softmax-triplet預測值,讓Net 2輸出的softmax-triplet逼近Mean-Net 1的softmax-triplet預測值。通過該損失函數的設計,該文有效地解決了傳統三元損失函數無法支持"軟"標籤訓練的局限性。"軟"三元損失函數可以有效提升無監督領域自適應在行人重識別任務中的精度,實驗詳情參見原論文消融學習的對比實驗。

2.5、算法流程


該文提出的"同步平均教學"框架利用"硬"/"軟"分類損失和"硬"/"軟"三元損失聯合訓練,在每個訓練iteration中,主要由三步組成:

  1. 通過"平均模型"計算分類預測和三元組特徵的"軟"偽標籤;
  2. 通過損失函數的反向傳播更新Net 1和Net 2的參數;
  3. 通過參數加權平均法更新Mean-Net 1和Mean-Net 2的參數。


3、實驗結果


該文在四個行人重識別任務上進行了驗證,精度均比現有最先進的方法 [2,3] 提升十個點以上,即將媲美有監督學習的性能。論文中使用K-Means聚類進行實驗,在每個行人重識別任務中都對不同的偽類別數(表格中表示為MMT-偽類別數)進行了驗證。發現無需設定特定的數目,均可獲得最先進的結果。另外,開源的代碼中包含了基於DBSCAN的實驗腳本,可以進一步提升性能,感興趣的同學可以嘗試。論文中的消融研究有效證明了"同步平均教學"框架的設計有效性和可解釋性,在這裡就不細細展開了。

4、總結


該文針對基於聚類的無監督領域自適應方法中無法避免的偽標籤噪聲問題展開了研究,提出使用"同步平均教學"框架在線生成並優化更為魯棒和可信的"軟"偽標籤,並設計了針對三元組的合理偽標籤以及對應的損失函數,在四個行人重識別任務中獲得超出最先進算法13%-18%的精度。


參考資料:

[1] Y. Ge, et al. Mutual Mean-Teaching: Pseudo Label Refinery for Unsupervised Domain Adaptation on Person Re-identification. ICLR, 2020.

[2] X. Zhang, et al. Self-training with progressive augmentation for unsupervised cross-domain person re-identification. ICCV, 2019.

[3] F. Yang, et al. Self-similarity grouping: A simple unsupervised cross domain adaptation approach for person re-identification. ICCV, 2019.

文章來源: https://twgreatdaily.com/zh-cn/IfplK3EBfwtFQPkdDKl0.html