7th place approach (upura part)

Summary

  1. 探索的データ分析で、データが4分類できると特定
  2. 得た知見を踏まえ、2000程度の特徴量を作成
  3. LightGBMのfeature importance上位50件を抽出
  4. Neural Networkに投入(CV: 20.6, LB: 20.8)

1. 探索的データ分析

データを少し眺めると、現実世界ではあり得ない綺麗なデータである印象を受けました。恐らく人為的に作られたデータなので、何かしらの軸で区切った際に特徴的な分布が存在するだろうという仮定のもと、探索的なデータ分析を実行しました。

結果として、後にdiscussionに投稿されたデータの4分類に気づきました。

image.png

2. 特徴量エンジニアリング

以下のように is_tokyo_osaka_and_partner というデータの4分類を示す特徴量を作成しました。この特徴量は、Cross validationの際にも利用しています。

from sklearn import preprocessing
from sklearn.model_selection import StratifiedKFold
​
​
def label_encoding(train, test, target_cols):
    for f in target_cols:
        lbl = preprocessing.LabelEncoder()
        lbl.fit(list(train[f].values) + list(test[f].values))
        train[f] = lbl.transform(list(train[f].values))
        test[f] = lbl.transform(list(test[f].values))
    return train, test
​
​
train['is_tokyo_osaka'] = train['area'].isin(['東京都', '大阪府']).astype(int)
test['is_tokyo_osaka'] = test['area'].isin(['東京都', '大阪府']).astype(int)
​
train['is_tokyo_osaka_and_partner'] = train['is_tokyo_osaka'].astype(str) + train['partner'].astype(str)
test['is_tokyo_osaka_and_partner'] = test['is_tokyo_osaka'].astype(str) + test['partner'].astype(str)
​
encode_cols = ['is_tokyo_osaka_and_partner']
train, test = label_encoding(train, test, encode_cols)
​
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=71)
​
train['fold_id'] = np.nan
for i, (train_index, valid_index) in enumerate(cv.split(train, train['is_tokyo_osaka_and_partner'])):
    train.loc[valid_index, 'fold_id'] = i
​
train['fold_id'].to_csv('../outputs/model/fold_id.csv', index=False)

それぞれの分類の中で commutesalary には、線形のような関係があります。この辺りの特徴を捉えられるような特徴量を追加しました。例えば「同一の is_tokyo_osaka_and_partner 内での commute の平均との差分」などです。

3. 特徴選択

LightGBMのfeature importance上位50件を抽出しました。上位100, 150, 200件の場合も試しましたが、性能に大きな違いが出なかったので計算量を重視し50件を採用しています。LightGBMでLB21.2程度でした。

AwesomeScreenshot-github-upura-probspace-salary-prediction-blob-master-experiments-exp18_lgbm_baseline_fi_param01.ipynb-2019-12-23_11_50_part1.png

高画質版はこちら

4. Neural Network

この段階でmaguchiiさんとチームを組みました。「Neural Networkでベストスコアを出している」という情報を共有していただき、自分もLightGBMからNeural Networkに差し替えました。

最終的なNeural Networkの構造は次の通りです。

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
area (InputLayer)               (None, 1)            0                                            
__________________________________________________________________________________________________
is_tokyo_osaka_and_partner (Inp (None, 1)            0                                            
__________________________________________________________________________________________________
embedding_3 (Embedding)         (None, 1, 4)         188         area[0][0]                       
__________________________________________________________________________________________________
embedding_4 (Embedding)         (None, 1, 4)         16          is_tokyo_osaka_and_partner[0][0] 
__________________________________________________________________________________________________
numerical (InputLayer)          (None, 48)           0                                            
__________________________________________________________________________________________________
concatenate_3 (Concatenate)     (None, 1, 8)         0           embedding_3[0][0]                
                                                                 embedding_4[0][0]                
__________________________________________________________________________________________________
dense_9 (Dense)                 (None, 32)           1568        numerical[0][0]                  
__________________________________________________________________________________________________
flatten_2 (Flatten)             (None, 8)            0           concatenate_3[0][0]              
__________________________________________________________________________________________________
batch_normalization_5 (BatchNor (None, 32)           128         dense_9[0][0]                    
__________________________________________________________________________________________________
dense_8 (Dense)                 (None, 4)            36          flatten_2[0][0]                  
__________________________________________________________________________________________________
p_re_lu_4 (PReLU)               (None, 32)           32          batch_normalization_5[0][0]      
__________________________________________________________________________________________________
batch_normalization_4 (BatchNor (None, 4)            16          dense_8[0][0]                    
__________________________________________________________________________________________________
dropout_3 (Dropout)             (None, 32)           0           p_re_lu_4[0][0]                  
__________________________________________________________________________________________________
p_re_lu_3 (PReLU)               (None, 4)            4           batch_normalization_4[0][0]      
__________________________________________________________________________________________________
concatenate_4 (Concatenate)     (None, 36)           0           dropout_3[0][0]                  
                                                                 p_re_lu_3[0][0]                  
__________________________________________________________________________________________________
dense_10 (Dense)                (None, 4)            148         concatenate_4[0][0]              
__________________________________________________________________________________________________
dense_11 (Dense)                (None, 36)           180         dense_10[0][0]                   
__________________________________________________________________________________________________
multiply_2 (Multiply)           (None, 36)           0           concatenate_4[0][0]              
                                                                 dense_11[0][0]                   
__________________________________________________________________________________________________
batch_normalization_6 (BatchNor (None, 36)           144         multiply_2[0][0]                 
__________________________________________________________________________________________________
dropout_4 (Dropout)             (None, 36)           0           batch_normalization_6[0][0]      
__________________________________________________________________________________________________
dense_12 (Dense)                (None, 1000)         37000       dropout_4[0][0]                  
__________________________________________________________________________________________________
dense_13 (Dense)                (None, 800)          800800      dense_12[0][0]                   
__________________________________________________________________________________________________
dense_14 (Dense)                (None, 300)          240300      dense_13[0][0]                   
__________________________________________________________________________________________________
out1 (Dense)                    (None, 1)            301         dense_14[0][0]                   
==================================================================================================
Total params: 1,080,861
Trainable params: 1,080,717
Non-trainable params: 144
__________________________________________________________________________________________________

おわりに

恐らく人為的なデータという特異性からか、Neural Networkが性能を大いに発揮するコンペティションだったと思います。チームを組み、テーブルデータに対するNeural Networkの利用などに関して議論を交わしてくださったmaguchiiさんにお礼申し上げます。

おまけ

import pandas as pd


train = pd.read_csv('../datasets/data/train_data.csv')
test = pd.read_csv('../datasets/data/test_data.csv')
df = pd.concat([train, test], sort=False)
df.query('position==0 and age==22 and area=="岡山県" and sex==2 and partner==0 and num_child==0 and education==0 and service_length==4 and study_time==0.0 and overtime==0.0')

Screen_Shot_2019-12-23_at_12.12.11.png

Favicon
new user
コメントするには 新規登録 もしくは ログイン が必要です。