見出し画像

【実装付きAI論文解説】TableGPT2: A Large Multimodal Model with Tabular Data Integration|TECH BLOG #02


1. はじめに

はじめまして、AIVALIX株式会社の取締役CIO(Chief Information Officer)を務めております、田中勇吾と申します。AIVALIXでは、LLMを活用したRAG(Retrieval-Augmented Generation)システムの開発をはじめ、ホームページやモバイルアプリの開発など、幅広い技術領域を担当しております。

さて、弊社ではAIエンジニアによるテックブログの発信に力を入れております。このテックブログでは、最新のAI論文について実装付きで分かりやすく解説することを目的としています。少しでも多くの方にAIの技術や可能性を感じていただければ幸いです。

今回は、テックブログ第2回として、表データを扱うために特化して訓練されたモデル「TableGPT2」について解説をまとめました。ぜひ最後までお読みいただけると嬉しいです。また、フォローしていただけますと今後の発信の励みになります。よろしくお願いいたします。

2. TableGPT2 論文解説

表形式データは、実世界の様々な領域で重要な役割を担っていますが、テキストや画像と比較すると、LLMを用いた表データの処理の研究は進んでいません。

TableGPT2は、大規模な表データセットを用いて事前学習、ファインチューニングされたモデルです。広範なトレーニングにより、TableGPT2 は、一般言語とコーディングの能力を維持しながら、表タスクに秀でています。既存手法では対応できないような、現実の大規模で複雑な表データにも対応できると述べられています。

2-1. Methods

TableGPT2は、表データを解釈するセマンティックテーブルエンコーダーと、Qwen2.5モデルをベースとしトレーニングされたLLMからなります。

TableGPT2で最も特徴的なのが、セマンティックテーブルエンコーダーです。

この仕組みは以下のようなLLMの課題を解決し、表データの構造的な理解を目指します。

  • 構造の不適合:表形式データの二次元構造(行×列)は、LLMが主に処理するテキストデータの一次元シーケンシャル構造とは大きく異なるため、LLMは表特有のパターンを理解する能力が低い。

  • 表現の非効率性:大規模な表やデータベースをLLMに与えることは非常にトークン効率が悪い。

  • 一般化能力の欠如:実世界の表はノイズや曖昧さを含む場合が多く、これをそのままモデルに入力すると誤解釈を招く。表全体の一部のみをコンテキストに含めた場合、モデルが偏った解釈をする可能性がある。

また、Qwen2.5モデルは以下の2つの主要なステップを経て訓練されています。

  1. 継続的事前学習(CPT): コード能力や推論能力を強化。

  2. 教師ありファインチューニング(SFT): 特定のデータ分析タスクに対応する性能を向上。


セマンティックテーブルエンコーダー(Semantic Table Encoder)

アーキテクチャ
セマンティックテーブルエンコーダーは、以下の3つの主要要素で構成されています。

1️⃣ セマンティックテーブルエンコーダー

表形式データを入力として、構造を反映したセマンティック埋め込み(semantic embedding)を生成します。

まず表 T にsentence transformer φ を適用して各セルのデータをエンべディングします。

その後、二次元注意モジュール(Bi-Dimensional Attention)で表全体の構造的意味を学習します。

BERT形式の双方向アテンションを行方向と列方向に交互に適用し、列内の分布特性と列間の関係の両方を捉えられるようにしています。

表では行や列の順序はあまり重要でないため、行や列の位置に依存しないように、ポジショナルエンコーディングを使用しない設計となっています。

2️⃣ テーブル-テキストアダプター(Table-Language Adapter)

Adapterを導入し、セルレベルの埋め込みを列レベルで集約し、LLMのテキスト埋め込みと統合可能な形式に変換します。

k個の学習可能なクエリと各列のセル埋め込み間でQ-Former形式のクロスアテンションを実行し、m行の表を、LLMの埋め込み次元d'に合わせたサイズkの固定長の列表現に変換します。

3️⃣ 動的なコンテキスト統合

表形式データ(列埋め込み)とメタデータ(列名、データ型など)を統合するために、以下の統一形式を使用します。この情報をLLMのプロンプトに挿入します。


トレーニング
Figure2の右上に示されているように、3つのステップでエンコーダをトレーニングしています。

  • エンコーダ事前学習(Encoder Pre-Training)

セマンティックテーブルエンコーダーが表形式データの基本的な構造と意味を理解できるように事前学習を行います。コントラスト学習(対照学習)を用い、同じ列の異なるセルのデータは近いもの、異なる列のデータは遠いものと判断するように学習します。これにより、同じ列内の一貫した特徴を学習し、異なる列間の違いを強調します。

  • 特徴の整合性(Feature Alignment)

表形式データと自然言語テキストの特徴量を整合させるために、データセットを構築し、テーブルエンコーダーとアダプター、LLMを共同でトレーニングします。表データセットを用いて、セルが属する列の予測、列に対応するセル値の予測、質問生成、表題生成、行の要約などのタスクを行います。

  • 共同指示チューニング(Instruction Tuning)

下流タスクでの最適なパフォーマンスのために、セマンティックテーブルエンコーダー、アダプター、およびLLMの共同での教師ありファインチューニングを実施します。


継続的事前学習(Continual Pretraining, CPT)
TableGPT2は、Qwen-2.5をベースとし、表形式データに特化した能力を持たせるために、継続的事前学習を実施しています。特に、コード生成やデータ解析に必要なスキルを強化することを目的としています。

収集データについて

  • コードデータ: StackOverflow (stack-v2) やGitHubから収集。Python(51.2%)、SQL(12.8%)など。

  • 一般データ: 金融や数学など専門分野の教科書やKaggleなどのデータ解析資料。

  • データ量は86B(860億)トークンに達し、データの質を保証するために文書レベルとトークンレベルで選別。

    • 文書レベル: 品質が低いデータや不適切なデータを除外。

    • トークンレベル: コーパス内の各トークンの損失をスコア付けし、閾値0.6を下回るトークンは損失計算から除外し、損失を大幅に削減する高価値トークンのみを選別して使用。


教師ありファインチューニング(Supervised Fine-Tuning, SFT)

柔軟で動的なデータ分析タスクに対応するため、既存の表形式データセットに加え、115Kの表と479Kのサンプルを含むデータを収集し、ファインチューニングを行っています。

データセットは、単純な質問応答だけでなく、より複雑な操作(多段階の分析やツール使用)も含んでいます。表データだけでなく、他の形式のデータも利用することで、一般的なタスクに対するパフォーマンスも維持することを目指しました。

収集データについて

  1. 表データの収集

    1. 様々な分野や形式から10万件以上の表データを収集し、多様性を確保しています。

      • データベーステーブル: MySQL、PostgreSQLなどの公的データセット。

      • ウェブページのテーブル: 学術論文や公式ウェブサイトからの単純な表データ。

      • Excelテーブル: 政府データや財務報告書など、不規則な構造を含む表。

      • 研究用のテーブル: 研究で用いられるデータを含む表で、TableQAやNL2SQL向きの構造化された形式であることが多い。

      • 特殊フォーマットのテーブル: 請求書やレシートなどの特殊な形式と用途の表。

      • 予測タスク用テーブル: KaggleやUCI Machine Learning Repositoryなどの予測タスクに用いられる表。

  2. クエリと回答の生成

    1. 表に対し、GPT-4oなどのLLMを流用して200万件以上の多様なクエリと回答を生成しました。

    2. データの取得・挿入・削除・変更、計算、統計、データの可視化のクエリカテゴリがあります。また、ユーザーにさらなる情報を求める必要がある「曖昧な質問」というカテゴリも設定されています。

不正確・冗長なデータを除去してデータ品質を確保しています。また、データ拡張を行い、多様性を持たせたデータセットを構築することで、モデルの汎用性とロバスト性を向上させます。

以上で得られたデータセットを用いて、テーブルエンコーダーとLLMに対し教師付きファインチューニングを行います。

一般タスク(自然言語応答など)と表タスク(コード生成、欠損値補完、データ操作など)の両方でモデルをファインチューニングすることで、特定のタスクに偏らず、汎用性を維持します。タスク内容に応じて、エンコーダーのみを更新するモードと、LLMを含むモデル全体をトレーニングするモードを動的に切り替えます。


エージェントフレームワーク(Agent Framework)

論文中では、実世界におけるモデルの応用という点を重視しており、TableGPT2を実世界のアプリケーションに統合し、データ解析、可視化、自動化をサポートするためのエージェントフレームワークについても述べられています。LLMを用いた一般的なエージェントと同様のシステムなので、今回は割愛します。

1-2. Result

表関連タスク向けのモデル(TableLLMs)とオープンソースの汎用LLMが比較対象として用いられています。

評価には、6カテゴリの既存ベンチマークと、新しいベンチマーク「RealTabBench」を使用しました。

既存ベンチマーク

  • Table Understanding: 表理解

    • TURL:表の解釈と拡張に関する4つのタスクからなる表理解のベンチマーク。

    • 表の基本的な理解能力を評価する。

  • Table Question Answering(TableQA): 質問応答

    • 表データに基づき、自然言語での質問に答える。

    • WikiTableQuestion (WikiTQ), WikiSQL: 標準的なリレーショナルテーブルのQAデータセット。

    • HiTab: 複雑な階層型テーブルのQAデータセット。

    • FeTaQA: 自由形式の質問応答。

    • HybridQA: 表とテキストデータからの多段階推論。

    • 複雑な表から情報を解釈し、正確な回答を提供するモデルの能力を調べる。

  • Fact Verification: 事実検証

    • TabFact, FEVEROUSの2つのベンチマーク。

    • 表の情報に基づき記述の正確性を判断するモデルの能力を評価する。

  • Table to Text generation(Table2Text): 表からテキスト生成

    • ToTTo:構造化された表データから、首尾一貫した有益なテキスト記述を生成する。

    • 表形式の情報を効果的に要約し、言語化するモデルの能力を評価する。

  • Natural Language to SQL(NL2SQL): 自然言語からSQL文生成

    • 表データを合成・拡張し、他のモデルやシステムで利用可能にする。

  • Holistic Table Evaluation: 全体的な表の評価

    • TableBench:18の産業分野にわたる、手動で注釈された表質問応答(TableQA)ベンチマーク。

    • 事実検証、数値推論、データ分析、コードベースのチャート可視化の4つの重要な表分析タスク。

    • symbolic chain-of-thought (SCoT), textual chain-of-thought (TCoT), program-of-thought (PoT), and direct prompting (DP)の4つの推論方法を使用。

    • 実世界のシナリオにおけるLLMの複雑な推論能力を評価する。

新ベンチマーク「RealTabBench

  • 既存のベンチマークは豊富にあるが、ほとんどが比較的単純なタスクに焦点を当てており、実世界の使用シナリオと一致していない。より実用的なデータセットとして、ビジネスインテリジェンス(BI)シナリオの実世界の表から360の複雑なデータ表を収集し、それらに基づいて6,000の現実的で複雑な問い合わせ文を作成した。

  • 特に実用上問題となる以下の2つの表の特性に着目した。

    • 曖昧性(ambiguity):曖昧な項目を含む表を扱い、モデルの解釈能力を調べる。

    • 不規則性(irregularity):セルの結合や不均一なレイアウトなど、不規則なレイアウトや複雑なフォーマットを持つ表を扱う。

  • 一貫性、情報の完全性、セキュリティの3つの観点から評価。

  • 人間の評価者と評価用LLMを組み合わせてスコアを算出。

評価結果

Table5は、TableGPT2(7B及び72B)を、上記のベンチマークにおいて他のモデルと比較した結果です。TableGPT2はほぼすべてのベンチマークにおいて他のモデルよりも優れた結果を達成しています。また、TableGPT2はGPT-4oと同等またはそれ以上の結果を得ています。特に、複雑な階層型テーブルのデータセット「HiTab」において、現在のほとんどのLLMは低いスコアを示している一方で、TableGPT2は、Qwen2.5ベースのモデルに対して実行精度を60%以上改善するという大幅な向上を示しています。

Table6はRealTabBenchデータセットでの結果です。TableGPT2-7Bは全てのタスクで最高のスコアを示しています。

一般的な能力

Table7は、表分析に限らない一般的なモデルの能力を示しています。コーディング能力や一般的な言語理解能力を測るベンチマークにおいては、表分析に特化したTableLLMのスコアは低いです。一方、TableGPT2はこれらのタスクにおいても汎用モデルに匹敵する性能を維持しています。

3. TableGPT2 実装してみた

TableGPT2のモデルはHugging Faceで公開されており、ダウンロードして利用することができます。非常に簡易的ではありますが、Google ColabでTableGPT-7Bモデルを動かしてみた結果を以下に記します。

Google Colabの有料版でA100 GPUを利用しました。無料版でT4 GPUを利用するとギリギリ動くかどうかというくらいです。
モデルのダウンロード及び実行はHugging Faceのページの「QuickStart」の項に従って行いました。

モデルのダウンロード

!pip install transformers>=4.37.0

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("tablegpt/TableGPT2-7B", torch_dtype='auto', device_map='auto')
tokenizer = AutoTokenizer.from_pretrained("tablegpt/TableGPT2-7B")

if torch.cuda.is_available():
    model = model.to("cuda")

モデルの実行

example_prompt_template = """Given access to several pandas dataframes, write the Python code to answer the user's question.

/*
"{var_name}.head().to_string(index=False)" as follows:
{df_info}
*/

Question: {user_question}
"""

question = "昭和25年の埼玉県の男性人口を教えて。"

prompt = example_prompt_template.format(
    var_name="df",
    df_info=df.head().to_string(index=False),
    user_question=question,
)

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": prompt},
]
text = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

generated_ids = model.generate(**model_inputs, max_new_tokens=512)
generated_ids = [
    output_ids[len(input_ids) :]
    for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]

response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

政府統計のページから、国勢調査の「男女別人口ー全国,都道府県(大正9年〜平成27年)」のcsvファイルをダウンロードし、利用しました。大正9年から平成27年まで5年ごとに、全国及び各都道府県の人口(総数・男女別)を記載したデータです。この表データについて、質問をしてみます。

質問:「昭和25年の埼玉県の男性人口を教えて。」
回答:
我需要从数据框中筛选出昭和25年(1950年)的埼玉县的数据,并提取其中的男性人口信息。

Python code:

# 筛选出昭和25年(1950年)的埼玉县的数据
saitama_population_1950 = df[(df["元号"] == "昭和") & (df["西暦(年)"] == 1950) & (df["都道府県名"] == "埼玉県")]

# 提取男性人口信息
saitama_population_male_1950 = saitama_population_1950["人口(男)"].values[0]
print(saitama_population_male_1950)

5秒ほどで回答が返ってきました。日本語出力には対応していないようで、中国語で回答が出力されています。答えを得るためのPythonコードが出力されており、年を和暦ではなく西暦で指定している点は気になりますが、正しそうです。

このコードを実行すると「1049695」と正しい答えが返ってきます。

続いてもう1つ質問をしてみます。
質問:「埼玉県の人口が最も多いのは何年ですか。」
回答:
我需要筛选出都道府县名为“埼玉県”的行,然后找出该行中人口(総数)最大的年份。

Python code:

# 筛选出都道府县名为“埼玉県”的行
saitama_population = df[df["都道府県名"] == "埼玉県"]

# 找出人口(総数)最大的年份
max_population_year = saitama_population.loc[
    saitama_population["人口(総数)"].idxmax(), "和暦(年)"
]
print(max_population_year)

こちらも5秒ほどで回答が返ってきました。コードを実行すると「27.0」という数値が返ってきます。和暦だけを表示していますが、これは平成27年を指しているようです。実際にデータを見てみると、埼玉県の人口は年々増えており、平成27年が最多となっています。

今回はTableGPT2単体で試してみましたが、論文中にもあるように、Agentフレームワークに組み込み、Pythonコードを実行できる環境で用いると、コードの実行結果をフィードバックして回答を改善できるため、複雑な質問にも正確な返答ができることが期待されます。

TableGPT2は、テキストとは異なる表形式データの構造に着目し、表に特化したエンコーダを作成した点が革新的です。一方、複雑な表形式や質問に対してはまだ精度が不十分なところもあるように思います。表形式のデータの分析は様々な場面で必要とされるタスクなので、LLMで自動化することができれば非常に便利になるでしょう。今後さらなる精度の高い手法が出てくるのか、注目したいところです。

今回のテックブログは以上になります。最後までお読みいただき、ありがとうございました。この記事が、少しでも皆様の参考になり、新たな気づきや挑戦のきっかけとなれば幸いです。今後も引き続き、有益な情報をお届けしてまいりますので、どうぞご期待ください。


AIVALIX株式会社では、AIに関する開発やR&Dの案件を常時承っております。また、共同開発・共同研究にご協力いただける企業様も広く募集しております。

AI技術を通じた新たな価値創造に向け、一緒に挑戦していただけるパートナーを心よりお待ちしております。ご興味をお持ちの際は、ぜひお気軽にお問い合わせください。

▼ コーポレートサイトはこちら ▼


いいなと思ったら応援しよう!