作者:Prajwal Shreyas
編譯:ronghuaiyang
導讀
對於列表型的數據,使用深度學習的方法來進行預處理的方法。
在本博客中,我將帶你通過keras上的深度學習網絡,了解創建分類變量嵌入的步驟。
傳統的嵌入
在我們使用的大多數數據源中,我們主要會遇到兩種類型的變量:
- 連續變量:這些通常是整數或小數,有無限個可能的值,例如計算機內存單元,比如1GB, 2GB等。
- 分類變量:這些是離散變量,用於根據某些特徵分割的數據。計算機內存的類型,如RAM內存、內部硬碟、外部硬碟等。
當我們構建一個ML模型時,通常需要對分類變量進行轉換,然後才能在算法中使用它。應用的轉換對模型的性能有很大的影響,特別是當數據具有大量類別的分類特徵時。一些常見的轉換的例子包括:
One-Hot編碼:這裡我們為每個類別值轉換為一個新列,並為該列分配一個' 1 '或' 0 ' (True/False)值。
二進位編碼:這樣創建的特性少於one-hot,同時保留列中值的一些唯一性。它能很好地處理高維有序數據。
然而,這些常用的轉換並沒有捕獲分類變量之間的關係。
數據
為了演示深度嵌入的應用,讓我們以Kaggle中的bike sharing數據為例。
我們可以看到數據集中有很多列。為了演示這個概念,我們只使用數據中的date_dt、cnt和mnth列。
傳統的one-hot編碼將產生12列,每個月1列。然而在這種類型的嵌入中,每周的每一天都同樣重要,每個月之間沒有關係。
在下面的圖表中,我們可以看到每個月的季節模式。我們可以看到第4到9個月是高峰期。第0、1、10、11個月是自行車租賃需求較低的幾個月。
此外,當我們繪製每個月的日常使用情況時,用不同的顏色表示,我們可以看到每個月中的一些每周模式。
理想情況下,我們希望通過使用嵌入來捕獲這種關係。在下一節中,我們將研究如何使用構建在keras之上的深度網絡來生成這些嵌入。
深度嵌入
代碼如下所示,我們將建立一個感知器網絡與dense層網絡和一個 『relu』 激活函數。
網絡的輸入,即' x '變量,為月號。這是一年中每個月的數字表示,範圍從0到11。因此,input_dim被設置為12。
網絡的輸出即' y '是縮放後的' cnt '值。也可以增加 『y』的維度以包含其他連續變量。在這裡,當我們使用一個連續變量時,我們將把最後輸出的dense層的節點設為1。我們將進行50個epoch來訓練模型。
embedding_size = 3
model = models.Sequential()
model.add(Embedding(input_dim = 12, output_dim = embedding_size, input_length = 1, name="embedding"))
model.add(Flatten())
model.add(Dense(50, activation="relu"))
model.add(Dense(15, activation="relu"))
model.add(Dense(1))
model.compile(loss = "mse", optimizer = "adam", metrics=["accuracy"])
model.fit(x = data_small_df['mnth'].as_matrix(), y=data_small_df['cnt_Scaled'].as_matrix() , epochs = 50, batch_size = 4)
模型參數
嵌入層:這裡我們為分類變量指定嵌入大小。在本例中是3,如果我們增加它,它將捕獲更多關於分類變量之間關係的細節。Jeremy Howard提出了以下選擇嵌入大小的解決方案:
# m is the no of categories per feature
embedding_size = min(50, m+1/ 2)
我們使用「adam「優化器,使用均方誤差損失函數。Adam比sgd(隨機梯度下降)更受歡迎,因為它具有更快的自適應學習速度。
結果
每個月的最終嵌入結果如下。這裡「0」代表一月,「11」代表十二月。
當我們用3D圖來觀察這一點時,我們可以清楚地看到月份之間的關係。具有相似「cnt」的月份被更緊密地分組在一起,例如第4個月到第9個月彼此非常相似。
總結
綜上所述,我們已經看到,通過使用Cat2Vec(分類變量到向量),我們可以使用低維嵌入來表示高基數的分類變量,同時保持每個類別之間的關係。
英文原文:https://towardsdatascience.com/deep-embeddings-for-categorical-variables-cat2vec-b05c8ab63ac0
請長按或掃描二維碼關注本公眾號