テクノロジー

DeepFMを使ったCTR予測

はじめに

D2Cのビジネスエンジニアリング部リサーチチームのT.Yoshiiです。 リサーチチームでは日々新しい手法や技術を調査し、既存のサービスの精度の向上や新しい事業の立ち上げなどに役立てられないかと画策しております。 本記事を読んで少しでもリサーチチームがどんなことをしているのかを知ってもらうことができれば幸いです。

背景・目的

弊社はdocomoのdmenuやメッセージサービスを通して様々な広告を日々配信していて、配信したユーザーの年齢や加入サービスなどのデモグラフィック情報を蓄積しています。我々はこうして集められた膨大なデータを調査し配信ロジックを最適化しています。 「ユーザーが広告に興味を持ったか」の指標として、「ユーザーが広告をクリックしたか」というものが挙げられます。配信した広告に対してユーザーが広告をクリックした割合をCTR(Click Through Rate)と呼び、このCTRを予測することは広告の入札額や配信ロジックを決定する上で非常に重要なテーマとなっています。 本記事では広告のCTR予測をテーマとした論文(DeepFM: A Factorization-Machine based Neural Network for CTR Prediction)で紹介されている手法DeepFMをpytorchで実装し、通常のDNN(Deep Neural Network)より本当に精度が出るのか検証を行っていきます。

DeepFM

DeepFMは2017年に提案された論文で、通常のDNNに更に特徴量どうしの交互作用について計算できるFM(Factorization Machines)を加えたモデルとなっています。概要については以下の図のように記すことができます。
DeepFM(Huifeng et al. 2017)の概略図。quadratic-termとDNNは同一のEmbedding層を経由しています。
FMモデルでは二次の交互作用についてまで考慮し、一次の交互作用と二次の交互作用の和で表されます。 \begin{eqnarray}
y_{\rm{FM}}
&=&
<w,x> +
\sum_{j_1 = 1}^{d}
\sum_{j_2 = j_1 + 1}^{d}
x_{j_1}{\cdot}x_{j_2} \\
&=& (\rm{LinearTerm}) + (\rm{QuadraticTerm})
\end{eqnarray} FMモデル内で一次と二次の交互作用の項をそれぞれLinearTerm, QuadraticTermとしました。 従って、DeepFMはDNNの項とFMの項の足し合わせで表現されるので、 \begin{eqnarray}
y_{\rm{DeepFM}}
&=&
y_{\rm{DNN}} + y_{\rm{FM}} \\
&=&
y_{\rm{DNN}} + (\rm{LinearTerm}) + (\rm{QuadraticTerm})
\end{eqnarray} と表すことができます。 LinearTerm, QuadraticTerm, DNNは独立に学習することができるので、それぞれの精度を比較するため上のように命名しておきました。

実験

データ

検証では論文で使われていたCriteoLabsの提供するデータセットを使用します。このデータセットは以前kaggleのコンペで採用されたもので、Criteo社の配信するディスプレイ広告についてCTR予測を行うことができます。データの内容について以下に列挙します。

データ内容

  • train.csv: Criteo社の一週間分のトラフィックデータの一部を使ったトレーニングデータ。
  • 各行がCriteoが配信したディスプレイ広告に対応している。
  • データサイズの都合上正例と負例は異なる割合でサブサンプリングされている。
  • 正例と負例の割合は1:3
  • データは時系列順に並んでいる。
  • 4,580万行
カラム
  • Label: ターゲット変数。広告をクリックしてれば1、してなければ0
  • l1-l13: 整数値で表せる特徴量。
  • C1-C26: カテゴリカル変数で表せる特徴量。すべての変数は32bitのハッシュ値で匿名化されている。
  • 特徴量の内容は明らかにされていない
テストデータ(test.csv)も用意されていましたが、Labelが公開されていないためトレーニングデータの一部を使って評価を行います。 データ全体を使うと計算に時間がかかりメモリを膨大に必要とするため、今回はデータ量の1/4(4の倍数行)だけを抽出して使用することにしました。また、論文ではトレーニングデータとテストデータをランダムサンプリングしていましたが、Criteoのデータセットが時系列順になっていること、実用上は学習した期間のデータとは異なる機関のデータを予測必要があることを考えて、次のようにトレーニングデータとテストデータを分類します。
train.csvの切り分け方。train.csvが時系列順に並んでいる1週間分のデータであることを考慮して、先頭から6/7をトレーニング、残りの1/7をテストデータとして使用します。さらに、トレーニングデータのうち1/5をvalidationとして使い、hold outで精度の検証を行います。

モデル

検証で使用するモデルはLinearTerm, QuadraticTerm, DNN, DeepFMと昨今kaggleなどで驚異的な精度を示している勾配ブースティング系のモデルであるLightGBM(LGBM)として、精度の比較を行います。 LinearTerm, QuadraticTerm, DNN, DeepFMのNN系の学習をする際すべてのモデルでoptimizer: Adam(lr=0.01)、Batch Size: $2^{15}$、Epoch: 50 としました。 LGBMの学習パラメータは今回使用するデータセットサイズが大きいこと、不均衡さが著しくないことを考えて、learning_rate: 0.01とnum_boost_round: 10000とし、それ以外のパラメータは公式ドキュメントの初期パラメータで学習しました。

評価関数

CTR予測ではモデルの分類性能の他に確率値の精度も比較したいので、モデルごとにAUC、LogLossの2種類の評価関数を使用してモデルの精度を検証しました。モデルを学習する際、ValidationデータのLogLossの値が最小になるところで、最適な学習ステップと決定しました。その最適なステップまでの学習時間も今回は比較します。

実装環境

NN系の学習ではAWSのインスタンスタイプ:p3.2xlarge(NVIDIA Tesla V100 GPU メモリ16GB, 8CPU メモリ61GB)、LGBMではGCPで32CPU メモリ208 GBで検証しました。

結果

今回の検証で得られた結果は以下のようになりました。

モデルごとの精度

modelAUCLogLoss計算時間 [min]
LinearTerm0.77410.4762~94
QuadraticTerm0.76660.4824~70
DNN0.77880.4732~8
DeepFM0.78190.4700~8
DeepFM(LinearTerm pre-train)0.78630.4661~12 (+ ~94)
LGBM0.79510.457872
FMモデルはパラメータが少なく学習が安定するまで時間がかかりましたが、DNN、DeepFMは反対にパラメータ数が多いためすぐに収束し過学習を始めてしまいました。
DeepFMは学習する際、QuadraticTermとDNNは特徴量を同一の方法でEmbeddingしますが、LinearTermのみ別のEmbedding層を通します。 このことからDeepFMの学習をする際、LinearTermだけを先に学習しあとからDeepFMを学習させた場合も比較用のモデルとして採用しました。 論文で紹介されいる結果(DeepFM AUC: 0.8007, LogLoss:0.45083)と同程度の精度が出ていないのは、データを縮小したのとテストデータのサンプリング方法が異なることが挙げられます。ですが、同じ学習条件でDNNよりもAUC、LogLossの全てに対して精度が出ていることが確認できました。 LGBMはやはり精度が高く、DeepFMでも精度を抜くことができませんでした。LGBMはパラメータチューニングをすることでもう少し精度を高めることができるかもしれませんが今回の目的から逸脱してしまうため、検証はここまでとします。

バギングの効果

ついでに、勾配ブースティングとNNで異なるロジックで分類をしていると思われるので、バギングでどれくらい精度が向上するのかを調べてみました。
modelAUCLogLoss
LGBM + DeepFM0.79670.4569

まとめ

今回はpytorchで実装したDeepFMを使ってCriteoのデータセットでCTR予測を行い論文で主張するようにDNN以上の精度が出ているかの検証を行いました。CTR予測において、よりよいモデルを突き詰めることは広告主にとってもユーザーにとってもメリットとなります。こうした既存のサービスの精度向上につなげるためにも我々リサーチチームは常に新しい技術や手法の調査を行っています。


関連タグ