初探Factorization Machine(II)

YL-Tsai
14 min readJan 24, 2021

探索如何透過Tensorflow實作Factorization Machine

前言

在上一篇初探Factorization_Machine(I)中,筆者帶著大家對Factorization Machine(以下簡稱FM)有了理論上的認識,本篇筆者會透過兩組原始碼來分析如何透過tensorflow來實作FM,並和各種模型在MNIST以及Movielens上跑分,最後也會使用MNIST以及Movielens比較模型準確度、訓練時間、推論速度等,那麼讓我們開始吧!

原始碼Survey

FM實際上在2010年就發表了,之後也有各種改進版本,例如任意階乘的Interaction,找原始碼的過程中會找到各種不同的實作,以下列出筆者survey的候選清單:

筆者自製圖表

由於筆者首次認識FM,以理解實作原理為重,實作上看起來比較好懂的為主要考量,工程實現上為了可以更快,有查到蠻多Cython甚至C/C++的實作,以Python作為Wrapper來操控,就不是本次原始碼分析的範圍,本篇原始碼重點會放在:

  1. FM的latent factor如何實作與訓練
  2. 以FM的模型設計邏輯為核心,解構FMRegressor/FMClassifier與FMCore的繼承關係
  3. 該份原始碼的修改彈性
  4. FM與其他模型的比較(準確度、推論速度、訓練時間等)
  5. FM以Dense Matrix / Sparse Matrix輸入的表現行為

並分析tffm以及tensor-fm兩個函式庫

剖析的程式碼也會一並上傳到這裡供讀者們參考

FM的實作與訓練

在上一篇初探Factorization_Machine(I)中,論文推導了公式來表達如何對Feature Interaction做latent factor的萃取,公式如下:

paper

那麼最後這一行實作部分的程式碼長怎樣呢?

tensor-fm給了我們一個相當簡易的版本,在tensor-fm的base.py中,可以看到一個定義為 fm 的函數:

https://github.com/gmodena/tensor-fm

看起來還算挺直白的,幾乎就是方程式直接翻譯過來(或許tf2真的開始提供一些方便性了…)

Interaction term : tensor X 和 tensor V 內積 並平方 減去 tensor X 平方 和 tensor V 平方

怎麼訓練呢?

https://github.com/gmodena/tensor-fm
https://github.com/gmodena/tensor-fm

我們可以看到先進行初始化,並透過max_iter來進行迭代,loss function的部分則考慮了cross-entropy以及MSE

這裡我們也可以看到,該份原始碼並沒有對論文中所提及的Sparse Input進行實作

不太記得論文中對Sparse Input可以訓練更快的分析? 參考初探Factorization Machine(I) — The computational trick小節

Regularization的部分則是實作了L1以及L2

https://github.com/gmodena/tensor-fm

並且在 sklearn.py 中,該作者透過繼承 BaseEstimator 來實作 BaseFactorizationMachine

https://github.com/gmodena/tensor-fm

接著使用 BaseFactorizationMachine 以及 RegressorMixin 建構出 Regressor ,搭配 ClassifierMixin 建構出 Classifier

Regressor :

https://github.com/gmodena/tensor-fm
https://github.com/gmodena/tensor-fm

Classifier :

https://github.com/gmodena/tensor-fm
https://github.com/gmodena/tensor-fm

以上我們也可以看到optimizer的部分是寫死使用 Adam ,classifier只支援二元分類等( LabelBinarizer )

以上可以看到透過tf2以及sklearn來實作,如實可以產生相當好的可讀性,也協助了我們理解latent factor建構,訓練的精髓

修改彈性方面,若要修改loss function/regularization function/sparse input能夠直接被base.py進行修改,optimizer的部分則會在sklearn.py進行修改。

任意階乘

tffm則是拓展到了任意階乘,但由於使用tensorflow 1.8實作,計算圖的部分必須自己管理,就會寫更多的程式碼,可以彈性修改的地方也就越多,筆者使用了SOURCE TRAIL來進行程式碼分析,主要是除了使用IDE快速跳轉之外,透過視覺化介面也有助於掌握整體架構:

https://github.com/geffy/tffm

原始碼架構上有4個類別,從圖來看繼承關係是

TFFMCore -> TFFMBaseModel -> TFFMClassifier / TFFMRegressor

FMClassifier / FMRegressor

https://github.com/geffy/tffm
https://github.com/geffy/tffm

FMClassifier / FMRegressor 處在最高階的封裝,

FMRegressor僅提供了fit, predict,而loss function只支援了mse

FMClassifier除了支援fit, predict之外,也支援了predict_prob以及sample weight上的調整,loss function則是logistic loss

FMBaseModel

https://github.com/geffy/tffm

FMBaseModel處在第2階段封裝,定義了較多只要是個模型大致上都需要的功能,像是

是個模型可能都需要的 : batch_size, n_epochs, sample_weight, verbose

模型狀態儲存及tensorflow session生命週期管理 : save_state, load_state, session, session_config, destroy

模型監控相關的(tensorboard-related) : summary_writer, need_logs, log_dir

並且希望只要是個model,就一定要有predict方法,intercept以及weights設計成唯讀的形式:

https://github.com/geffy/tffm

主要訓練過程也在這個類別被定義

FMCore

https://github.com/geffy/tffm

FMCore處在第一階封裝,定義了計算圖的初始化,怎麼跑計算圖,計算圖的長相等,也定義了演算法實作細節

演算法實作細節 :

input_type : (dense or sparse)

order : 特徵交叉項要做到幾階(對,這份實作支援任意階數的特徵交叉)

rank : latent factor的數量

至於其他比較常見的內容像是 : n_features, optimizer, loss_function, reduce_loss, loss

其他用tensorflow.name_scope包起來的方法和變量 : init_placeholders, init_main_block, init_regularization, init_target, …

右排的wrapper則是針對desne data 以及 sparse data做了不一樣的計算方式,sparse data並不需要整個矩陣都存下來,並且可以只針對index做運算,其他的都是0所以可以不用算

只要存shape以及有值的index,該數值是什麼,可以看到有count_nonzero, matmul, pow等方法

最重要的計算圖在 main_block 當中:

https://github.com/geffy/tffm

上面計算圖是High-Order的計算圖,可以看到對於order以及power特進行了處理

小結來說,tffm同樣地我們也可以根據我們的需要修改loss function,optimizer,和tensor-fm比起來,tffm多實作了以下功能:

  1. 分類問題的sampling_weights
  2. Sparse input的優化
  3. 模型儲存

經過以上分析,就能夠了解該程式碼所能2次開發的彈性,再按照自己的開發能力去推算某項功能的開發時間

模型實測

這裡筆者使用了tffm來進行實測,主要是因為可以對Sparse/Dense input進行比較的緣故(而不需要多開發)

準確度/訓練時間/推論時間分析

MNIST

許多讀者會傻眼的地方是,為什麼找MNIST而不是Movielens呢?

事實上當然可以找Movielens,不過也能夠找MNIST,雖然tffm的範例直接就是MNIST,讓開發者可以快速地進行比較,不過筆者反而更喜歡拿論文模型去跑非論文上的benchmark,這麼做可以有以下好處:

  1. 讓我們卸下paper上好厲害好棒棒的SOTA濾鏡,進而了解該模型普遍來說的性能究竟如何
  2. 了解模型是否是對該資料集的測試集過擬合了,同樣地回到第一點,了解該模型Generalization的性能

但其實MNIST資料集也不是亂找的,FM的設計是給Sparse data,MNIST把28x28x1變成781x1,確實也是sparse的,當然可以更量化地來看這件事,以數字3以及數字5為例

資料集 : MNIST, 數字3以及數字5

筆者自製表格

讀者們可以注意到筆者除了選常用有名的模型之外,也特地把SVM — order 2選出來比,主要是想要驗證paper所說的,SVM — order 2是不是真的訓練不起來?

模型分析(準確度、訓練時間、推論速度)

筆者自製圖表

以上實測是在筆者的MacPro RAM 16G 4 Core上測試,sklearn以及lightgbm模型都是 n_job=-1 ,並且推論速度是經過warm start之後再跑50次平均,根據以上的數據可以發現:

  1. 準確度可以看到FM order = 3的準確度最高,接下來則是LGBM
  2. SVM order = 2在這個資料集是訓練的起來的(但需要訓練很久),Paper說的訓練不起來的情況,應該會滿足某些標準,例如Non-Zero ratio達到多少以下
  3. tffm的sparse input實作上可能有一些問題,在我的mac上跑訓練和推論都是dense比較快,這與理論上不太符合,sparse input應該要可以比dense input來得快才是
  4. FM order = 2 及 FM order = 3 如實能夠做到自動化特徵工程以及還不錯的訓練速度/推論速度,這一點也是Paper的價值所在
  5. LR 可以有最快的訓練速度以及推論速度,單個樣本的推論速度在我的mac上可以跑到0.06毫秒,實際應用場景可以LR, FM都訓練,接著再評估場景需求,要快且準度還能接受,就使用LR,可以接受稍微慢一些的推論速度但更準確,則使用FM,推論速度和準確度都可以納入商業指標中計算期望值
  6. 若是需要準確度和訓練速度/推論速度上的權衡,LGBM也可以加入作為實驗對象,LGBM表現也不錯,且推論時間也不慢(但MNIST資料集應該是難度相對簡單)
  7. 用tensorflow寫的好處是可以透過GPU在加速訓練及推論,相關benchmark可以看這裡

特徵數量 — 時間複雜度分析

筆者自製圖表

FM — order = 2

latent factors = 100

inference time samples = 50

errbar : std

  1. 無論dense/sparse,推論時間複雜度大致上與特徵數量呈現線性關係
  2. 再一次地看到sparse推論時間比dense還要久,這與理論情況不符合,或許工程實作上有些地方存在優化空間

Movielens

當然,經典的Movielens資料集還是可以作為一個比較有感覺的benchmark,由於推論速度已經比較過,因此這裏就僅列出整體表現的表格,Regression指標除了MSE以外,也加入了MAPE(平均百分誤差),用於了解模型預測平均來說,與真實值的差距比例

資料集 : Movielens(ml-100k), 目標 : 電影評分

筆者自製表格
筆者自製表格

從以上分析可以發現 :

  1. 資料集特性上可以看到目標不平均的狀況並不會到很糟,1分的資料數量約是4分數量的1/5,就筆者的經驗上來說,這個不平衡的程度不到很大,從不平衡狀態也可以預期3~5分的資料點會預測的較為準確,1~2分的資料點準確度則會稍差點
  2. Non-Zero rate降到了0.07%,比起MNIST近乎只剩下1/100,有趣的是Support Vector Regression還是可以訓練的
  3. 在此特定情況下,LGBM反而在準確度表現上最差
  4. 平均來說,模型預測表現和真實值的誤差約3成,FM確實表現的最好,可以到28%,LGBM最差,可以差到33%
  5. 在此特定情況下,FM latent factor取100個,並不需要3階特徵交叉,2階已經把特徵抓的差不多了(2階和3階有近乎一致的準確度)

總結

本篇文章筆者帶大家認識了FM模型的兩篇tensorflow實作,了解了實作者們怎麼樣實作FM模型,想改的話又要從哪裡改,也發現了工程上待優化的地方,此外得到了fm模型的預測速度,準確度比較,希望本篇文章有讓讀者們從實作角度上更了解FM,程式碼也都放在這裡,若想要進行後續的分析和實驗也都可以參考,也希望本篇文章帶給讀者一些啟發!

--

--

YL-Tsai

Machine Learning Engineer with 4y+ experience | Exploring the data world | Recommendation, Search, Ad System.