表形式データの分類(TabNet)
はじめに
表形式データの分類モデルを構築します。
コードは下記にあります。
https://github.com/toshi-4886/neural_network/blob/main/tabular/2_TabNet.ipynb
概要
adultデータセットを用いて収入を予測するモデルを構築します。
モデルはTabNetを使用します。
TabNet
TabNetは表形式データのために提案されたニューラルネットワークです。
下記の提案論文の図に示されている通り、モデルの主な特徴は2点あります。
Attentive Transformerという機構を利用して、入力の少数の特徴を選択して予測
上記を繰り返し、各予測を統合して全体の予測結果とする。ただし、2回目以降は、前回の予測も入力に加える。
実装
1. ライブラリのインポート
TabNetはpytorch_tabularの実装を使用します。
!pip install ucimlrepo
from ucimlrepo import fetch_ucirepo
!pip install pytorch_tabular
import sys
import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import collections
from sklearn.model_selection import train_test_split
from sklearn import metrics
import torch
import torch.nn as nn
import torch.nn.functional as F
2. 実行環境の確認
使用するライブラリのバージョンや、GPU環境を確認します。
print('Python:', sys.version)
print('PyTorch:', torch.__version__)
!nvidia-smi
Python: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
PyTorch: 2.1.0+cu121
Sat Jan 20 05:20:32 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05 Driver Version: 535.104.05 CUDA Version: 12.2 |
|-----------------------------------------+----------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+======================+======================|
| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 |
| N/A 39C P8 9W / 70W | 0MiB / 15360MiB | 0% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
+---------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=======================================================================================|
| No running processes found |
+---------------------------------------------------------------------------------------+
3. データセットの用意
adultデータセットをダウンロードして、学習に使用できる形式に整形します。
adult = fetch_ucirepo(id=2)
X = adult.data.features
y = adult.data.targets['income']
y = y.replace({'<=50K.': 0, '<=50K':0, '>50K.': 1, '>50K': 1})
# カテゴリ変数の特定
categorical = X.columns[X.dtypes == 'object'].tolist()
continuous = X.columns[X.dtypes != 'object'].tolist()
# 教師データとテストデータにランダムに分割
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
print(X_train.shape, X_test.shape)
print(collections.Counter(y_train), collections.Counter(y_test))
X_train['income'] = y_train
X_test['income'] = y_test
4. ニューラルネットワークの定義
学習などの設定は、configに引数として渡し、TabularModelを作成します。
from pytorch_tabular import TabularModel
from pytorch_tabular.models import TabNetModelConfig
from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig, ExperimentConfig
data_config = DataConfig(
target=['income'],
continuous_cols=continuous,
categorical_cols=categorical,
)
trainer_config = TrainerConfig(
auto_lr_find=False,
batch_size=128,
max_epochs=100,
)
optimizer_config = OptimizerConfig(
optimizer_params = {'weight_decay':1e-4}
)
model_config = TabNetModelConfig(
task="classification",
)
tabular_model = TabularModel(
data_config=data_config,
model_config=model_config,
optimizer_config=optimizer_config,
trainer_config=trainer_config,
)
5. 学習
ニューラルネットワークの学習を行います。
tabular_model.fit(train=X_train)
6. 学習結果の表示
テストデータの損失と精度を評価します。
res = tabular_model.evaluate(X_test)
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric ┃ DataLoader 0 ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ test_accuracy │ 0.8428308367729187 │
│ test_loss │ 0.3377942144870758 │
└───────────────────────────┴───────────────────────────┘
おわりに
今回の結果
今回の設定では、テスト精度は84%程度となりました。
全結合ニューラルネットワークよりも少し低い結果となっています。
ただし、データセットやハイパーパラメータによっても性能は異なるため、もう少し検証は必要です。
次にやること
他の表形式データのために提案されたニューラルネットワークも試してみようと思います。
参考資料
S. O. Arik and T. Pfister, TabNet: Attentive Interpretable Tabular Learning, AAAI, 2021.
PyTorch Tabular
https://pytorch-tabular.readthedocs.io/en/latest/