見出し画像

QLoRAチューニングモデルをAPIから遊ぶメモ

このメモを読むと

・QLoRA学習済みモデルをAPI経由で呼び出せる
・ストリーム出力ができるようになる

検証環境

・Windows11
・VRAM24GB
・ローカル(Anaconda)
・python3.10
・2023/7/B時点

事前準備

Anacondaを使うメモ|おれっち (note.com)
Gitを使うメモ|おれっち (note.com)

APIとは

あるソフトウェアから別のソフトウェアに機能を提供するための「接続口」のようなもの。便利らしい。ChatGPTをローカルPC上のスクリプトで動かすときにもAPI経由でアレコレをやり取りしているのです。
こちらの記事でローカルLLMをAPI化する方法が紹介されていたので、QLoRAチューニングモデルを呼び出して遊んでみます。

すること

 ・QLoRAチューニング
 ・API経由で文章生成

環境構築

とても簡単です!

1. 仮想環境を作成し、環境切替

conda create -n apitest python=3.10
activate apitest

2. 追加パッケージのインストール

pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
pip install https://github.com/jllllll/bitsandbytes-windows-webui/raw/main/bitsandbytes-0.39.0-py3-none-any.whl
pip install git+https://github.com/huggingface/accelerate.git
pip install git+https://github.com/huggingface/transformers.git
pip install git+https://github.com/huggingface/peft.git
pip install datasets sentencepiece protobuf==3.20.0
pip install fastapi uvicorn

完了です!

QLoRAチューニング

下記記事の"学習"を行うことで作成できます。

モデルはこちらをお借りしました。
 ベースモデル:cyberagent/open-calm-7b
 データセット:bbz662bbz/databricks-dolly-15k-ja-gozaru

API経由で文章生成


下記三つを作成
 APIファイル
 文章生成用スクリプト
 API呼び出しスクリプト

APIファイル(例:gen_api.py)

from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from generate import textgen, request #文章生成用スクリプト

# 基本パラメータ
base_model = "cyberagent/open-calm-7b"
peft_name = "test-Ocalm-7b"

llm = textgen(base_model, peft_name)
app = FastAPI()

@app.get("/")
def read_root():
    return {"Hello": "World"}

@app.get("/chat/")
def read_root():
    return {"status": "ready to chat"}

@app.post("/chat/")
async def response(request:request):
    return await llm.generate(request)

@app.post("/chat-stream/")
async def response_stream(request:request):
    return StreamingResponse(llm.generate_stream(request))

文章生成用スクリプト(例:generate.py)

import asyncio
import torch
import json
from threading import Thread
from typing import AsyncIterator
from pydantic import BaseModel
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer

class request(BaseModel):
    messages: list
    role: bool = True
    max_new_tokens: int = 512
    temperature: float = 0.8

class textgen:

    def __init__(self,base_model,peft_name):
        # ベースモデル量子化パラ設定
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16
        )
        self.tokenizer = AutoTokenizer.from_pretrained(base_model)
        self.eos_token = self.tokenizer.decode([self.tokenizer.eos_token_id])
        self.model = AutoModelForCausalLM.from_pretrained(base_model, quantization_config=bnb_config, device_map="auto")

        self.model = PeftModel.from_pretrained(self.model, peft_name)
        self.model.eval()# 評価モード

    def request2prompt(self,request):
        if request.role:
            prompt = [
                f"{uttr['speaker']}: {uttr['text']}"
                for uttr in request.messages
            ]
            # print(prompt)
            prompt = "<NL>".join(prompt)
            prompt = (
                prompt
                + "<NL>"
                + "システム: "
            )

        else:
            prompt = [
                f"{uttr['text']}"
                for uttr in request.messages
            ]
            prompt = "<NL>".join(prompt) + "<NL>"
            
        return prompt
    
    async def generate_stream(self,request) -> AsyncIterator[str]:

        prompt = self.request2prompt(request)
        input_ids = self.tokenizer(prompt, 
                          return_tensors="pt", 
                          truncation=True, 
                          add_special_tokens=False).input_ids.cuda()
        streamer = TextIteratorStreamer(self.tokenizer)

        generation_kwargs = dict(
            input_ids=input_ids,
            streamer=streamer,
            max_new_tokens=request.max_new_tokens,
            do_sample=True,
            temperature=request.temperature,
            pad_token_id=self.tokenizer.pad_token_id,
            bos_token_id=self.tokenizer.bos_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
            bad_words_ids=[[self.tokenizer.bos_token_id]],
            num_return_sequences=1,
        )

        thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
        thread.start()

        for output in streamer: 
            if not output:
                continue
            print(output)
            if self.eos_token not in output:
                yield json.dumps({
                    "speaker": "システム",
                    "text":output.replace("<NL>", "\n"),
                    "continue":True}, ensure_ascii=False)
            else:
                yield json.dumps({
                    "speaker": "システム",
                    "text":output.replace(self.eos_token, ""),
                    "continue":False}, 
                    ensure_ascii=False)
            await asyncio.sleep(0)
            
    async def generate(self,request):

        prompt = self.request2prompt(request)
        input_ids = self.tokenizer(prompt, 
                          return_tensors="pt", 
                          truncation=True, 
                          add_special_tokens=False).input_ids.cuda()

        with torch.no_grad():
            output_ids = self.model.generate(
                input_ids=input_ids,
                do_sample=True,
                max_new_tokens=request.max_new_tokens,
                temperature=request.temperature,
                pad_token_id=self.tokenizer.pad_token_id,
                bos_token_id=self.tokenizer.bos_token_id,
                eos_token_id=self.tokenizer.eos_token_id
            )

        output = self.tokenizer.decode(output_ids.tolist()[0][input_ids.size(1):])
        output = output.replace("<NL>", "\n").replace(self.eos_token, "")
        res_message = {
            "speaker": "システム",
            "text": output
        }

        return res_message 

API呼び出しスクリプト(post.py)

import requests
import json
import sys

def post_question(question):
    url = "http://127.0.0.1:8000/chat-stream/"
    headers = {
        'accept': 'application/json',
        'Content-Type': 'application/json',
    }
    data = {
        "messages": [{
            "speaker": "ユーザー",
            "text": question
        }],
        "role": True,
        "max_new_tokens": 512,
        "temperature": 0.8
    }
    response = requests.post(url, headers=headers, json=data, stream=True)

    decoder = json.JSONDecoder()
    buffer = ""
    for chunk in response.iter_content(chunk_size=1024):
        buffer += chunk.decode()
        while buffer:
            try:
                result, index = decoder.raw_decode(buffer)
                text = result.get('text', '')
                print(text,end="",flush=True)
                if not result.get('continue'):
                    return
                buffer = buffer[index:]
            except ValueError:
                # Not enough data to decode, fetch more
                break

if __name__ == "__main__":
    while True:
        question = input('\n'+"Question: ")
        if question =="":
            break
        post_question(question)

APIを起動(仮想環境にて

uvicorn gen_api:app --reload

API呼び出しスクリプトを実行(仮想環境にて

python post.py

生成できた!

おわり

APIできた!APIの動作確認にはFastAPI - Swagger UIをよく使う。

参考資料

最初のステップ - FastAPI (tiangolo.com)

いいなと思ったら応援しよう!