Paoの技術力を磨くブログ

機械学習やブロックチェーン等の技術を身に付けていくブログです。

要注目?DeepGBM: ニューラルネット+GBDT(速報)

KDD2019のPaper一覧で気になるものがあったので紹介します。

※記載時点でまだ論文公開、発表されておらず、こちら鮮度重視の記事です。 内容に誤りがある可能性は十分あるのでご了承ください。

DeepGBMとは

データマイニングのトップカンファレンスKDD2019で発表される予定の手法です。

Guolin Ke, Zhenhui Xu, Jia Zhang, Jiang Bian, and Tie-yan Liu. "DeepGBM: A Deep Learning Framework Distilled by GBDT for Online Prediction Tasks." In Proceedings of the 25th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining, ACM, 2019.

発表概要

  • KDD2019のOralで発表予定
  • 1st AuthorはLightGBMのAuthor(Microsoft)

問題意識

テーブルデータをベースとしたオンライン予測タスクについて以下の課題がある。 - GBDT(勾配ブースティング)はスパースなカテゴリに対して、弱い - NN(ニューラルネットワーク)はNumericなデータに対して弱い

提案手法

  • Numericなデータに対して、GBDTでの学習をNNで蒸留する
  • Categoricalなデータは、カテゴリをembeddingして、かつ相互作用を考慮したNN
  • NumericとCategoricalのoutputをconcatしてfull connect layerにつなげる

f:id:go5paopao:20190703193250j:plain

実験結果

  • 多くのデータセットに対して、DeepFMやGBDTよりも良い結果
  • ちなみにGBDTはLightGBMを利用

※データセットMalwareというのがあったのが気になった。これまさか。。?

個人的感想

テーブルデータではGBDTのほうが良い精度が出ることが多いが、確かにスパースなカテゴリデータには強くないと思う。CountEncodingやTargetEncodingで何とかすることが多い気がする。
逆にNNはカテゴリにはembedding層でいい感じにできるイメージ。
そこでNNとGBDT混ぜてやろうってのは良さそうに聞こえる。実際kaggleの上位もほぼNNとGBDTのアンサンブルだし。

GBDTをNNで蒸留するところが、鍵だと思うけど、論文も出てないし、いまいち理解できてない。(理解するまえにブログ書いた。) Githubでソースは公開されてるので、時間ある方いたら読んで教えてください(もしくは論文待とう)

LigbtGBMの人だし、KDDで採択されてるし、信頼できそう? 要注目です。

リンク

論文

まだ公開されてなさそう。 公開されたら更新します。

github

https://github.com/motefly/DeepGBM/blob/master/README.md

概要スライド

https://youtu.be/UzXNzW2s8Pw