見出し画像

ローカルLLMの長文推論、メモリ使用量を節約する方法:KVキャッシュの量子化


大規模言語モデル(LLM)において、メモリ使用量の効率化は非常に重要です。特に長文処理を行う場合です。モデルに入力するコンテクストが長くなるほど、メモリ消費量も増加します。

Mistral-7Bは、v0.1では、約4K tokensのSliding window attention(SWA)にて、コンテクスト長に消費されるメモリを低減しました。しかし、性能への影響からと考えますが、v0.2以降のアップデートで、SWAは排除されています。入力トークンを絞ることでメモリ容量を低減すれば、当然複雑性や表現力が低下してしまうのはイメージしやすいです。

KV キャッシュの量子化

Hugging Faceのブログ記事では、KVキャッシュ量子化が大規模言語モデル(LLM)での長文生成時のメモリ使用量を大幅に削減する方法を解説しています。KVキャッシュ量子化は、計算結果を保存して再利用することで、効率的なテキスト生成を実現します。数値の精度を下げる量子化により、消費者向けGPUでより長いテキスト生成が可能になります。記事では、実装方法やパフォーマンス比較、使用手順について詳しく説明されています。


generate()
メソッドは、効率を高め再計算を避けるためにキーとバリューをキャッシュする機能をサポートしていますが、キーとバリューのキャッシュはメモリを大量に消費し、特に大規模言語モデルの長文生成ではボトルネックとなります。generate()使用時にキャッシュを量子化することで、メモリ要求を大幅に削減できます。
キー・バリューキャッシュの量子化を有効にするには、generation_configcache_implementation="quantized"を指定する必要があります。量子化関連の引数は、辞書としてまたはQuantizedCacheConfigクラスのインスタンスとしてgeneration_configに渡します。デフォルトの量子化バックエンドはquantoです。

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16).to("cuda:0")
inputs = tokenizer("I like rock music because", return_tensors="pt").to(model.device)

out = model.generate(**inputs, do_sample=False, max_new_tokens=20, cache_implementation="quantized", cache_config={"nbits": 4, "backend": "quanto"})
print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])

out = model.generate(**inputs, do_sample=False, max_new_tokens=20)
print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])

キャッシュ量子化は、文脈の長さが短く、十分なGPU VRAMが利用可能な場合には、逆に不利になることがあります。

vLLMの場合

FP8(8ビット浮動小数点)フォーマットを用いたKVキャッシュによってメモリ使用量を節減して推論するオプションが設定されています。

FP8 E5M2 KVキャッシュ

FP8データフォーマットは2〜3ビットの仮数部を保持し、float/fp16/bfloat16およびfp8間の変換が可能です。

FP8 E5M2の設定例

以下のコードは、FP8 E5M2フォーマットを使用してKVキャッシュを有効にする方法を示しています。

from vllm import LLM, SamplingParams

# サンプルプロンプト
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

# サンプリングパラメータオブジェクトの作成
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# LLMの作成
llm = LLM(model="facebook/opt-125m", kv_cache_dtype="fp8")

# プロンプトからテキストを生成
outputs = llm.generate(prompts, sampling_params)

# 出力結果の表示
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

注意点として、現在のプレフィックスキャッシュはFP8 KVキャッシュが有効な場合には機能しません。異なるKVおよびキャッシュタイプに対応するため、forward_prefixカーネルが必要です。

FP8 E4M3 KVキャッシュ

FP8 E4M3フォーマットは、メモリのフットプリントをさらに減少させます。これにより、キャッシュに保存できるトークン数が増加し、スループットが向上します。FP8 E4M3は、4ビットの指数部と3ビットの仮数部を持ち、より高精度な浮動小数点数の表現が可能です。

FP8 E4M3の設定例

以下のコードは、FP8 E4M3フォーマットを使用してKVキャッシュを有効にする方法を示しています。

from vllm import LLM, SamplingParams

sampling_params = SamplingParams(temperature=1.3, top_p=0.8)

llm = LLM(model="meta-llama/Llama-2-7b-chat-hf",
          kv_cache_dtype="fp8",
          quantization_param_path="./tests/fp8_kv/llama2-7b-fp8-kv/kv_cache_scales.json")

prompt = "London is the capital of"

out = llm.generate(prompt, sampling_params)[0].outputs[0].text

print(out)

まとめ

FP8フォーマットを使用することで、LLMのメモリ使用量を大幅に削減し、長いコンテクストにも対応可能になります。FP8 E5M2およびE4M3 KVキャッシュは、精度を保ちながらメモリ効率を向上させるための強力なツールです。

ロングデータのLLM解析のために、数百GBデータをゴリゴリ扱うようになるのも、まもなくでしょう。
数百GBの特徴量ベクトルデータを、ニューラルネットワークに何回も通すことで、人間の脳では推論しきれない高度な課題を解析し、問題解決を得ようとすると思います。特に先端科学の分野や、資本がからむ経済分野では、推論競争になるので、いかに大容量で複雑なデータをLLMで解析するか、については、引き続き追っていきたいと思います。


いいなと思ったら応援しよう!