表形式データの分類(TabNet)

はじめに

表形式データの分類モデルを構築します。
コードは下記にあります。
https://github.com/toshi-4886/neural_network/blob/main/tabular/2_TabNet.ipynb

概要

  • adultデータセットを用いて収入を予測するモデルを構築します。

  • モデルはTabNetを使用します。

TabNet

TabNetは表形式データのために提案されたニューラルネットワークです。
下記の提案論文の図に示されている通り、モデルの主な特徴は2点あります。

  1. Attentive Transformerという機構を利用して、入力の少数の特徴を選択して予測

  2. 上記を繰り返し、各予測を統合して全体の予測結果とする。ただし、2回目以降は、前回の予測も入力に加える。

実装

1. ライブラリのインポート

TabNetはpytorch_tabularの実装を使用します。

!pip install ucimlrepo
from ucimlrepo import fetch_ucirepo

!pip install pytorch_tabular

import sys
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import collections
from sklearn.model_selection import train_test_split
from sklearn import metrics

import torch
import torch.nn as nn
import torch.nn.functional as F

2. 実行環境の確認

使用するライブラリのバージョンや、GPU環境を確認します。

print('Python:', sys.version)
print('PyTorch:', torch.__version__)
!nvidia-smi
Python: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
PyTorch: 2.1.0+cu121
Sat Jan 20 05:20:32 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  Tesla T4                       Off | 00000000:00:04.0 Off |                    0 |
| N/A   39C    P8               9W /  70W |      0MiB / 15360MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                                         
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|  No running processes found                                                           |
+---------------------------------------------------------------------------------------+

3. データセットの用意

adultデータセットをダウンロードして、学習に使用できる形式に整形します。

adult = fetch_ucirepo(id=2)

X = adult.data.features
y = adult.data.targets['income']

y = y.replace({'<=50K.': 0, '<=50K':0, '>50K.': 1, '>50K': 1})

# カテゴリ変数の特定
categorical = X.columns[X.dtypes == 'object'].tolist()
continuous = X.columns[X.dtypes != 'object'].tolist()


# 教師データとテストデータにランダムに分割
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

print(X_train.shape, X_test.shape)
print(collections.Counter(y_train), collections.Counter(y_test))

X_train['income'] = y_train
X_test['income'] = y_test

4. ニューラルネットワークの定義

学習などの設定は、configに引数として渡し、TabularModelを作成します。

from pytorch_tabular import TabularModel
from pytorch_tabular.models import TabNetModelConfig
from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig, ExperimentConfig

data_config = DataConfig(
    target=['income'],
    continuous_cols=continuous,
    categorical_cols=categorical,
)
trainer_config = TrainerConfig(
    auto_lr_find=False,
    batch_size=128,
    max_epochs=100,
)
optimizer_config = OptimizerConfig(
    optimizer_params = {'weight_decay':1e-4}
)
model_config = TabNetModelConfig(
    task="classification",
)

tabular_model = TabularModel(
    data_config=data_config,
    model_config=model_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config,
)

5. 学習

ニューラルネットワークの学習を行います。

tabular_model.fit(train=X_train)

6. 学習結果の表示

テストデータの損失と精度を評価します。

res = tabular_model.evaluate(X_test)
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metricDataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│       test_accuracy       │    0.8428308367729187     │
│         test_loss         │    0.3377942144870758     │
└───────────────────────────┴───────────────────────────┘

おわりに

今回の結果

今回の設定では、テスト精度は84%程度となりました。
全結合ニューラルネットワークよりも少し低い結果となっています。
ただし、データセットやハイパーパラメータによっても性能は異なるため、もう少し検証は必要です。

次にやること

他の表形式データのために提案されたニューラルネットワークも試してみようと思います。

参考資料

いいなと思ったら応援しよう!