識別形式語言能力不足,不完美的Transformer要克服自注意力的理論缺陷

2022-04-13   CDA數據分析師

原標題:識別形式語言能力不足,不完美的Transformer要克服自注意力的理論缺陷

作者:David Chiang、Peter Cholak

機器之心編譯

本文轉自:機器之心

最近一兩年,transformer 已經在 NLP、CV 等多樣化任務上實現了卓越的性能,並有一統 AI 領域的趨勢。那麼,推出已近五年的注意力機制真的是所有人需要的嗎?近日,有論文檢驗了 transformer 在兩種形式語言上的理論缺陷,並且設計了方法克服這種缺陷。文章還研究了可能出現的長度泛化的問題,並提出了相應的解決方案。

最近一兩年,transformer 已經在 NLP、CV 等多樣化任務上實現了卓越的性能,並有一統 AI 領域的趨勢。那麼,推出已近五年的注意力機制真的是所有人需要的嗎?近日,有論文檢驗了 transformer 在兩種形式語言上的理論缺陷,並且設計了方法克服這種缺陷。文章還研究了可能出現的長度泛化的問題,並提出了相應的解決方案。

儘管 transformer 模型在許多任務中都非常有效,但它們對一些看起來異常簡單的形式語言卻難以應付。Hahn (2020) 提出一個引理 5),來試圖解釋這一現象。這個引理是:改變一個輸入符號只會將 transformer 的輸出改變 𝑂(1/𝑛),其中 𝑛 是輸入字符串的長度。

因此,對於接收(即判定某個字符串是否屬於某個特定語言)只取決於單個輸入符號的語言,transformer 可能會以很高的準確度接受或拒絕字符串。但是對於大的 𝑛,它必須以較低的置信度做出決策,即給接受字符串的機率略高於 ½,而拒絕字符串的機率略低於 ½。更準確地說,隨著 𝑛 的增加,交叉熵接近每個字符串 1 比特,這是最壞情況的可能值。

近期,在論文《Overcoming a Theoretical Limitation of Self-Attention》中,美國聖母大學的兩位研究者用以下兩個正則語言(PARITY 和 FIRST)來檢驗這種局限性。

Hahn 引理適用於 PARITY,因為網絡必須關注到字符串的所有符號,並且其中任何一個符號的變化都會改變正確答案。研究者同時選擇了 FIRST 作為引理適用的最簡單語言示例之一。它只需要注意第一個符號,但因為更改這個符號會改變正確答案,所以該引理仍然適用。

儘管該引理可能被解釋為是什麼限制了 transformer 識別這些語言的能力,但研究者展示了三種可以克服這種限制的方法。

首先,文章通過顯式構造表明,以高準確度識別任意長度的語言的 transformer 確實是存在的。研究者已經實現了這些結構並通過實驗驗證了它們。正如 Hahn 引理所預測的那樣,隨著輸入長度的增加,這個構建的 transformer 的交叉熵接近 1 比特(也就是,僅比隨機猜測好一點)。但文章也表明,通過添加層歸一化,交叉熵可以任意接近零,而與字符串長度無關。

研究者在實踐中還發現,正如 Bhattamishra 等人所指出的,transformer 無法學習 PARITY。也許更令人驚訝的是,在學習 FIRST 時,transformer 可能難以從較短的字符串泛化到較長的字符串。儘管這不是 Hahn 引理的邏輯上可以推出的結果,但它是 Hahn 引理預測行為的結果。幸運的是,這個問題可以通過簡單的修改來解決,即將注意力的 logit 乘以 log 𝑛。此修改還改進了機器翻譯中在長度方面的泛化能力。

論文地址:https://arxiv.org/pdf/2202.12172.pdf

精確解決方案

克服 Hahn 引理所暗示的缺點的第一種方法是通過顯式構造表明 transformer 可以以高精度識別出上述提到的兩種語言。

針對 PARITY 的前饋神經網絡(FFNN)

Rumelhart 等人表明,對於任何長度𝑛都有一個前饋神經網絡 (FFNN) 可以計算長度正好為 𝑛 的字符串的 PARITY。他們還表明,隨機初始化的 FFNN 可以自動學習這麼做。

由於文章所提出構建方式部分基於他們的,因此詳細回顧他們的構建可能會有所幫助。設𝑤為輸入字符串,|𝑤| = 𝑛,𝑘是𝑤中 1 的個數。輸入是一個向量 x,使得 x_𝑖 = I[𝑤_𝑖 = 1]。第一層計算 𝑘 並將其與 1,2,...,n 進行比較:

因此,

第二層將奇數元素相加並減去偶數元素:

針對 PARITY 的 transformer

命題 1. 存在一個帶有 sigmoid 輸出層的 transformer,它可以識別(在上述意義上)任意長度字符串的 PARITY 語言。

最初,研究者將構造一個沒有層歸一化的 transformer 編碼器(即 LN(x) = x);然後展示如何添加層標準化。設 𝑘 是 1 在 𝑤 中出現的次數。網絡計算的所有向量都有 𝑑 = 9 維;如果顯示出較少的維度,則假設剩餘的維度為零。

詞和位置嵌入是:

研究者認為,位置編碼的第五維使用餘弦波是一個相當標準的選擇,儘管它的周期 (2) 比標準正弦編碼中的最短周期 (2𝜋) 短。第四維度誠然不是標準的;但是,研究者認為這依然是一種合理的編碼,並且非常容易計算。因此,單詞𝑤_𝑖的編碼是:

第二個 head 不做任何事情(W^V,1,2 = 0;query 和 key 可以是任何東西)。在殘差連接之後,可以得到:

在 Rumelhart 等人的構造中,下一步是使用階躍激活函數為每個 𝑖 計算 I[𝑖 ≤ 𝑘]。文章提出的構造有兩個不同之處。首先,激活函數採用 ReLU,而不是階躍激活函數。其次,因為注意力總和必須為 1,如果 𝑛是奇數,那麼偶數和奇數位置將獲得不同的注意力權重,因此奇數位置減去偶數位置的技巧將不起作用。相反,我們想要計算 I[𝑖 = 𝑘](如下圖 1)。

第一個 FFNN 有兩層,第一層是:

由此可以得出:

第二層採用線性的方式把這三個值結合在一起得到想要的 I[𝑖 = 𝑘]。

第二個自注意力層測試位置𝑘是偶數還是奇數。它使用兩個 head 來做到這一點,一個更強烈地關注奇數位置,一個更強烈地關注偶數位置;兩者的平均維度大小為 8:

針對 FIRST 的 transformer

接下來,研究者為 FIRST 構建一個 transformer。根據學習每個位置詞嵌入的常見做法(Gehring 等人,2017 年),他們使用位置編碼來測試一個詞是否在第 1 個位置 :

第一層 FFNN 計算一個新的組件(5)來測試是否 i = 1 以及 w_1 = 1。

第二個自注意力層只有一個單一的 head,這使得 CLS 關注於位置 1.

第二層 FFNN 什麼都不做(W^F,2,1 = b^F,2,1 = W^F,2,2 = b^F,2,2 = 0)。所以在 CLS 處(位置 0 處):

最後輸出層僅僅選擇組件 6。

實驗

文章使用 PyTorch 的 transformer 內置實現的修改版本(Paszke 等人,2019)實現了上述兩種建構。這些構造對長度從 [1, 1000] 採樣的字符串實現了完美的準確性。

然而,在下圖 2 中,紅色曲線(「沒有做層歸一化」)表明,隨著字符串變長,交叉熵接近每個字符串 1 比特的最壞可能值。

層歸一化

減輕或消除 Hahn 引理限制的第二種方法是層歸一化 (Ba et al., 2016),對於任何向量 x,其定義為

實驗中,𝛽 = 0 和𝛾 = 1,因此結果的均值近似為零和方差近似為 1。常數 𝜖 沒有出現在原始定義中(Ba et al., 2016),但為了數值穩定性,我們知道的所有實現中都添加了常數𝜖。

原始的 transformer 在每個殘差連接後立即執行層歸一化。在本節中,研究者修改了上面的兩個結構的層歸一化。,這一修改有兩個步驟。

去除中心

第一個是通過使網絡計算每個值𝑥以及 -𝑥來消除層歸一化的中心效應。新單詞的編碼是根據原始結構中的編碼定義的:

對於自注意力層的參數也是類似。

對於每個位置的 FFNN 參數也類似。

之後,每層的激活值為:

LN 的參數始終具有零均值,因此層標準化不會增加或減少任何內容。它確實縮放了激活,但在上面構建的兩個 transformer 的情形中,任何激活層都可以按任何正數進行縮放,而不會改變最終決策。

減少交叉熵

此外,在任何轉換器中,我們可以在任意 transformer 中使用層歸一化來將交叉熵縮小到想要的任意小,這與 Hahn 的引理 5 相反。在 Hahn 的公式中,像層歸一化這樣的位置相關的函數可以包含在他的 𝑓^act 中,但是 引理假設 𝑓^act 是 Lipschitz 連續的,而 ϵ = 0 的層歸一化不是。

命題 2. 對於任何具有層歸一化 (ϵ = 0) 並可以識別語言 L 的 transformer 𝑇,對任何 𝜂 > 0 而言,都存在一個可以以最多𝜂為交叉熵的、帶有層歸一化的識別語言 L 的 transformer。

證明。讓𝑑表示原始激活向量中的維數,𝐿是層數。然後添加一個新層,這個層中的自注意力不做任何事情 (W^V,𝐿+1,ℎ = 0),並且 FFNN 是根據原始輸出層定義的:

這會導致殘差連接除了 2 個維度外的所有維度為零,因此如果𝑠是原始輸出 logit,則此新層的輸出(層歸一化之前)為

現在,如果 ϵ = 0,層歸一化將該向量縮放到只具有單位方差,因此它變為:

新的輸出層只是簡單地選擇第一維,並把它拓展到 c 倍。

實驗

研究者測試了這一解決方案,它在層歸一化時進行了如上修改。上圖 2 顯示 ϵ > 0 的層歸一化提高了交叉熵,但它仍然隨著 𝑛 增長並接近 1。

可學習性

在本節中,研究者將轉向可學習性的問題,這時克服 Hahn 引理所提出的缺陷的第三種方法。

實驗:標準的 transformer

研究者在 PARITY 和 FIRST 上訓練 transformer。每個 transformer 都具有與對應的精確解相同的層數和頭數以及相同的固定位置編碼。與單詞編碼、自注意力相關的 FFNN 輸出的𝑑_model 為 16,而與 FFNN 隱藏層相關的𝑑_FFNN 為 64。殘差連接之後使用了層歸一化(ϵ = 10^-5)。

實驗使用 PyTorch 默認的初始化並使用 Adam (Kingma and Ba, 2015) 進行訓練,學習率為 3 × 10^−4 (Karpathy, 2016)。實驗沒有使用 dropout,因為它似乎沒有幫助。

FIRST 更容易學習,但壞消息是學習到的 transformer 不能很好地泛化到更長的句子。下圖 4(左列)顯示,當 transformer 在較短的字符串(𝑛 = 10、30、100、300)上從頭開始訓練並在較長的字符串(𝑛 = 1000)上進行測試時,準確度並不非常好。事實上,對於𝑛 = 10 上的訓練,準確度和隨機猜測類似。

長度取對數後拓展的注意力

幸運的是,這個問題很容易通過以 log 𝑛的比例縮放每個注意力層的 logits 來解決,即將注意力重新定義為

然後使用 c = 1 下的針對 FIRST 的 Flawed transformer:

命題 3。對於任何 𝜂 > 0,都存在一個帶有如公式 2 所定義的注意力的 transformer,它無論有無層歸一化,都可以以最多𝜂的交叉熵來識別 FIRST 語言。

證明。當沒有層歸一化時,3.3 節中描繪的模型中 c 設為 1,並對注意力的權重進行對數尺度的縮放,它可以將公式(1)中的 s 從公式(1)轉化為:

實驗:縮放的注意力

下圖 4(右欄)的 tranformer 模型使用了 log n 為縮放因子的注意力。

我們可以在稀缺資源英語 - 越南語的機器翻譯任務上使用開源 transformer 模型(即 Witwicky)時看到相似的效應(如下表 1 所示)。

當訓練集和測試即得長度分布一樣得時候,縮放注意力的 logits 沒有顯著的影響,但如果僅在只有中等甚至更短(小於 20)長度的句子上訓練,而測試句子長度大於中等長度(大於 20),縮放注意力則提高了 1 個 BLEU 分數,這在統計上已經很顯著了(p 值小於 0.01)。

點這裡關注我,記得標星哦~

CDA課程諮詢