見出し画像

novel-gptで小説を書く、gradioで簡単に長文を生成できるよにした。

japanese-novel-gptはAI BunChoで使用されている小説用LLMです。今回は小規模なGPUでも動くように8bit量子化を行い、gradioによる操作性の向上を目指して実装を試しました。オリジナルのコードをベースに長文も作成できるよう、UIで変数の調整もできるようにしています。

japanese-novel-gptはnpakaさんをはじめ、多くの方が何度も記事にされていますし、公式のコードは容易に動かすことができます。一方で長文の生成は少し工夫が必要です。またコマンドラインに出力される文字列は非常に読みずらいので、適時改行コードを入れて可読性も向上させました。

環境設定

まずは準備です。公式にある通りにインストールをします。8bit量子化をするので、bitsandbytesもインストールです。japanese-novel-gptは6Bもあり、トークンを大きくして長文を生成するためには16G-VRAMが必要です。今回は8bit量子化で12GクラスのGPUでも動くようにしています。

pip install transformers sentencepiece accelerate bitsandbytes

初期設定部分

from transformers import GPTJForCausalLM, AlbertTokenizer
import torch
import  gradio as gr

tokenizer = AlbertTokenizer.from_pretrained(
        'AIBunCho/japanese-novel-gpt-j-6b',
         keep_accents=True,
         remove_space=False)
#量子化なし VRAM=13.5G
#model = GPTJForCausalLM.from_pretrained("AIBunCho/japanese-novel-gpt-j-6b",
#        torch_dtype=torch.float16,
#        low_cpu_mem_usage=True)
#model.half()
#model.eval()
#if torch.cuda.is_available():
#    model = model.to("cuda")

#8bit量子化あり VRAM=7.8G@ 200トークン、8.2G@500トークン
model = GPTJForCausalLM.from_pretrained("AIBunCho/japanese-novel-gpt-j-6b",
        load_in_8bit=True,
        torch_dtype=torch.float16,
        device_map='auto')

8Bit量子化をしないコードはコメントアウトしてあります。特に難しい部分は無く、モデルも自動的にダウンロードされます。

長文を生成する工夫

max_new_tokensとmin-lengthを大きくすれば長文が書けるわけですが、生成進行状況がわかりずらいのと、大きくするとVRAMが不足するために、短く分けながら何度か連続で生成できるよう工夫したコードになっています。分割生成をする部分が以下の関数になります。入力ブロンプトを使いながら生成ごとに古い文章を削除しながら、新しい文章を追加生成する仕組みになっています。コード中にコメントを記載しているので難解な部分は無いと思います。主にテキストの処理を行いながら指定回数の生成を行う仕組みです。デフォルトでは5回の生成をトークン100で回しながら各ターンで3行ずつ古い文章を削除しています。
関数の最後は連続した生成分では読みにくいので、。や」「の部分に改行コードを挿入してgradioに渡しています。

#プロンプトの例
#ある朝、事件が起こりました。たしか、あれは3年前のことです。その時、天気は荒れて大雨でした。

def pre_generate(input_prompt,   gen_count,  temperature_val,  min_len,clr_len):
        sum_txt=""
        gen_count  =int(gen_count)
        min_len       =int(min_len)
        clr_len         =int(clr_len)
        print("gen_count=", gen_count,"min_len=", min_len ,"clr_len=",clr_len)
        org_input_prompt=input_prompt
        for i in range(gen_count ):
                print("ターン=",i)
                #input_promptの続きをgenerate()で生成、min_len x 3文字程度
                gen_out = generate(input_prompt,   temperature_val,  min_len)
                out_txt=gen_out.replace(input_prompt,"")       #gen_outからこのターンのプロンプトを削除(生成分の書き出しが完全に一致していることが条件)
                sum_txt=sum_txt+out_txt                                             #out_txtには新たに生成された文が残っているので、以前までの文に追加
                #生成された文からターン毎削除行数で指定した行を削除
                gen_txt=gen_out.replace(org_input_prompt,"")#生成された文から、入力したプロンプトのみ削除
                print("gen_txt     =",gen_txt)
                out_form=gen_txt.replace("。","。\n").replace("」「","」\n「")
                new_form=out_form.split("\n")[clr_len:]
                #新たに入力したプロンプトと削除したあとの文を結合して次のターンのプロンプトにする
                input_prompt=org_input_prompt + "".join(new_form)
                input_prompt_save=input_prompt
        out_txt=org_input_prompt + sum_txt    #入力したプロンプトの後ろに各ターンで生成された結合文を結合して最終的な生成文にする
        print("out_txt=",out_txt)
        #表示か見やすいように、。や 」「 には改行を挟む
        txt_out=out_txt.replace("。","。\n").replace("」「","」\n「")
        return  txt_out  

Generation

この部分はオリジナル通りです。max_new_tokensとmin-length、 temperatureを変数として受け取れるように改造しています。

def generate(prompt,   temperature_val,  min_len):
        input_ids = tokenizer.encode(
                prompt,
                add_special_tokens=False,
                return_tensors="pt"
                ).cuda()
        tokens = model.generate(
                input_ids.to(device=model.device),
                max_new_tokens=min_len,
                min_length= min_len,
                temperature=temperature_val,
                top_p=0.9,
                repetition_penalty=1.2,
                do_sample=True,
                pad_token_id=tokenizer.pad_token_id,
                bos_token_id=tokenizer.bos_token_id,
                eos_token_id=tokenizer.eos_token_id )
        out = tokenizer.decode(tokens[0], skip_special_tokens=True)
        return out

Gradio-UI部分

今回もBlocksを使用して記述しています。

全体のコード

from transformers import GPTJForCausalLM, AlbertTokenizer
import torch
import  gradio as gr

tokenizer = AlbertTokenizer.from_pretrained(
        'AIBunCho/japanese-novel-gpt-j-6b',
         keep_accents=True,
         remove_space=False)
#量子化なし VRAM=13.5G
#model = GPTJForCausalLM.from_pretrained("AIBunCho/japanese-novel-gpt-j-6b",
#        torch_dtype=torch.float16,
#        low_cpu_mem_usage=True)
#model.half()
#model.eval()
#if torch.cuda.is_available():
#    model = model.to("cuda")

#8bit量子化あり VRAM=7.8G@ 200トークン、8.2G@500トークン
model = GPTJForCausalLM.from_pretrained("AIBunCho/japanese-novel-gpt-j-6b",
        load_in_8bit=True,
        torch_dtype=torch.float16,
        device_map='auto')

#プロンプトの例
#ある朝、事件が起こりました。たしか、あれは3年前のことです。その時、天気は荒れて大雨でした。

def pre_generate(input_prompt,   gen_count,  temperature_val,  min_len,clr_len):
        sum_txt=""
        gen_count  =int(gen_count)
        min_len       =int(min_len)
        clr_len         =int(clr_len)
        print("gen_count=", gen_count,"min_len=", min_len ,"clr_len=",clr_len)
        org_input_prompt=input_prompt
        for i in range(gen_count ):
                print("ターン=",i)
                #input_promptの続きをgenerate()で生成、min_len x 3文字程度
                gen_out = generate(input_prompt,   temperature_val,  min_len)
                out_txt=gen_out.replace(input_prompt,"")       #gen_outからこのターンのプロンプトを削除(生成分の書き出しが完全に一致していることが条件)
                sum_txt=sum_txt+out_txt                                             #out_txtには新たに生成された文が残っているので、以前までの文に追加
                #生成された文からターン毎削除行数で指定した行を削除
                gen_txt=gen_out.replace(org_input_prompt,"")#生成された文から、入力したプロンプトのみ削除
                print("gen_txt     =",gen_txt)
                out_form=gen_txt.replace("。","。\n").replace("」「","」\n「")
                new_form=out_form.split("\n")[clr_len:]
                #新たに入力したプロンプトと削除したあとの文を結合して次のターンのプロンプトにする
                input_prompt=org_input_prompt + "".join(new_form)
                input_prompt_save=input_prompt
        out_txt=org_input_prompt + sum_txt    #入力したプロンプトの後ろに各ターンで生成された結合文を結合して最終的な生成文にする
        print("out_txt=",out_txt)
        #表示か見やすいように、。や 」「 には改行を挟む
        txt_out=out_txt.replace("。","。\n").replace("」「","」\n「")
        return  txt_out    

def generate(prompt,   temperature_val,  min_len):
        input_ids = tokenizer.encode(
                prompt,
                add_special_tokens=False,
                return_tensors="pt"
                ).cuda()
        tokens = model.generate(
                input_ids.to(device=model.device),
                max_new_tokens=min_len,
                min_length= min_len,
                temperature=temperature_val,
                top_p=0.9,
                repetition_penalty=1.2,
                do_sample=True,
                pad_token_id=tokenizer.pad_token_id,
                bos_token_id=tokenizer.bos_token_id,
                eos_token_id=tokenizer.eos_token_id )
        out = tokenizer.decode(tokens[0], skip_special_tokens=True)
        return out

with gr.Blocks() as webui:
      with gr.Row():
          with gr.Column():
              sys_prompt    = gr.Textbox(label="あらすじ/書き出し",lines=10, placeholder=" あらすじや、小説の書き出しを入力してください")
              with gr.Row():
                      gen_count       = gr.Number(5, label="生成ターン数")
                      clr_len                = gr.Number(3, label="ターン毎削除行数")
                      temperature  = gr.Number(0.7, label="temperature")
                      min_length     = gr.Number(100, label="min_length(x3文字程度")
              with gr.Row():
                  prompt_inpu    = gr.Button("Submit prompt",variant="primary")
                  #continue_next = gr.Button("続きを読む")
          with gr.Column():
              out_data=[gr.Textbox(label="生成小説")]
      prompt_inpu.click(pre_generate, inputs=[sys_prompt,  gen_count, temperature,  min_length ,clr_len], outputs=out_data )
webui.launch()

入出力の例

サンプルのプロンプトをコード中にコメントで入れています。ミステリーを生成するための書き出しです。

「ある朝、事件が起こりました。たしか、あれは3年前のことです。その時、天気は荒れて大雨でした。」

コンソールには各ターンでの生成文が表示され、進行状況がわかるようになっています。UIは以下の通りです。
左側に書き出し文やあらすじを入力します。こればLLMへのプロンプトになります。下部には変数入力部があります。
変数ターン数: 何回生成するか指定します。
ターン削除行数:ターンごとに削除する行数を指定します。大きくすると
        トークンが少なくなってmin-lenngthを大きくできますが
        物語の連続性に問題が生じる可能性があります。
tempertature: 他のLLM同様に生成文のバリエーションを指定します。
各入力BOXでは空白があるとエラーになります。入力した数値の後ろに空白がある場合もエラーです。

右側には改行で読みやすくなった文章が表示されています。

min-lenngthを500ぐらいにしてターンを10にするとずいぶんと長い文章が生成されます。その分時間もかかるので、覚悟が必要です。読むにもかなりの時間を要します。

気が付いたこと

物語は完結しません。途中で途切れた終わり方をします。なので続きが読みたくなります。「続きを読む」ボタンの設置が必要ですが、今回は評価を兼ねた実装なので組み込んでいません。ループ中の変数を記憶しておいてループを再開させるだけですが、導入部分から生成部分までを削除して、続きだけを表示させるために工夫が必要でしょう。改良時には設けたいと思います。

まとめ

とても興味深い小説が書けるようです。一部で辻褄が合わないストーリーになる事があったり、専門的な言葉が入る導入は苦手なような気がします。人の手を少し入れると読み応えのある小説に仕上がるのではないかと思いました。