見出し画像

novel-gptで小説を書く、gradioで簡単に長文を生成できるよにしたー(2)

昨日の記事で続きを読むが無いことを書きました。しかし、これが無いと、とても不便です。なので、連投になりますが、「続きを読む」ボタンを追加するとともに、「全文を表示」ボタンと「テキストファイルを作成」ボタンも追加して、概ね使えるアプリに近づけました。

新しいGUIです。ボタンが3個追加されています。

追加部分

def pre_generate(input_prompt,   gen_count,  temperature_val,  min_len,clr_len):
        global sum_txt
        global  input_prompt_save
        global  all_txt
        sum_txt=""
        input_prompt_save =input_prompt
        all_txt=""
        out_txt , input_prompt_save =txt_gen_count(input_prompt, gen_count,  temperature_val,  min_len,clr_len,  sum_txt, input_prompt_save)
        sum_txt=out_txt
        out_txt=input_prompt + sum_txt    #入力したプロンプトの後ろに各ターンで生成された結合文を結合して最終的な生成文にする
        all_txt=  out_txt
        txt_out=out_txt.replace("。","。\n").replace("」「","」\n「") #表示か見やすいように、。や 」「 には改行を挟む
        print("####新規",txt_out)
        return  txt_out

def continue_txt(input_prompt,  gen_count,  temperature_val,  min_len,clr_len):
       global sum_txt
       global  input_prompt_save
       global  all_txt
       pre_sum_txt=sum_txt
       txt_out, input_prompt_save=txt_gen_count(input_prompt, gen_count,  temperature_val,  min_len,clr_len,  sum_txt,input_prompt_save)
       sum_txt =  txt_out
       all_txt = input_prompt + txt_out       #全文
       txt_out=txt_out.replace(pre_sum_txt,"")
       txt_out=txt_out.replace("。","。\n").replace("」「","」\n「") #表示か見やすいように、。や 」「 には改行を挟む
       print("####新規",txt_out)
       return  txt_out 

新規の作成時のパラメータを txt_gen_countに渡す関数と、続きを生成する場合に以前のパラメータを引き継ぐ関数を設けました。これに伴い txt_gen_count内部も少し変更しています。

全文表示、ファイル化

def disp_all_txt():
        global  all_txt
        txt_out=all_txt.replace("。","。\n").replace("」「","」\n「") #表示か見やすいように、。や 」「 には改行を挟む
        return txt_out

def txt2file():
        global  all_txt
        txt_out=all_txt.replace("。","。\n").replace("」「","」\n「") #表示か見やすいように、。や 」「 には改行を挟む
        file_name = "gen_novel.txt"
        with open( file_name , "w") as file:
                file.write(txt_out)
        return file_name

gradioから呼びだされる簡単な関数です。「テキストファイル作成」ボタンをクリックするとtxt2file()が呼び出され、起動時のディレクトリ内に
gen_novel.tx
が生成されます。改行を追加した読みやすいフォーマットです。全文が書き出されます。

全コード

from transformers import GPTJForCausalLM, AlbertTokenizer
import torch
import  gradio as gr

tokenizer = AlbertTokenizer.from_pretrained(
        'AIBunCho/japanese-novel-gpt-j-6b',
         keep_accents=True,
         remove_space=False)
#量子化なし VRAM=13.5G
#model = GPTJForCausalLM.from_pretrained("AIBunCho/japanese-novel-gpt-j-6b",
#        torch_dtype=torch.float16,
#        low_cpu_mem_usage=True)
#model.half()
#model.eval()
#if torch.cuda.is_available():
#    model = model.to("cuda")

#8bit量子化あり VRAM=7.8G@ 200トークン、8.2G@500トークン
model = GPTJForCausalLM.from_pretrained("AIBunCho/japanese-novel-gpt-j-6b",
        load_in_8bit=True,
        torch_dtype=torch.float16,
        device_map='auto')

#プロンプトの例
#ある朝、事件が起こりました。たしか、あれは3年前のことです。その時、天気は荒れて大雨でした。

input_prompt_save =""
sum_txt=""
all_txt=""

def pre_generate(input_prompt,   gen_count,  temperature_val,  min_len,clr_len):
        global sum_txt
        global  input_prompt_save
        global  all_txt
        sum_txt=""
        input_prompt_save =input_prompt
        all_txt=""
        out_txt , input_prompt_save =txt_gen_count(input_prompt, gen_count,  temperature_val,  min_len,clr_len,  sum_txt, input_prompt_save)
        sum_txt=out_txt
        out_txt=input_prompt + sum_txt    #入力したプロンプトの後ろに各ターンで生成された結合文を結合して最終的な生成文にする
        all_txt=  out_txt
        txt_out=out_txt.replace("。","。\n").replace("」「","」\n「") #表示か見やすいように、。や 」「 には改行を挟む
        print("####新規",txt_out)
        return  txt_out

def continue_txt(input_prompt,  gen_count,  temperature_val,  min_len,clr_len):
       global sum_txt
       global  input_prompt_save
       global  all_txt
       pre_sum_txt=sum_txt
       txt_out, input_prompt_save=txt_gen_count(input_prompt, gen_count,  temperature_val,  min_len,clr_len,  sum_txt,input_prompt_save)
       sum_txt =  txt_out
       all_txt = input_prompt + txt_out       #全文
       txt_out=txt_out.replace(pre_sum_txt,"")
       txt_out=txt_out.replace("。","。\n").replace("」「","」\n「") #表示か見やすいように、。や 」「 には改行を挟む
       print("####新規",txt_out)
       return  txt_out 

def txt_gen_count(input_prompt,  gen_count,  temperature_val,  min_len,clr_len,  sum_txt ,input_prompt_save):
        gen_count  =int(gen_count)
        min_len       =int(min_len)
        clr_len         =int(clr_len)
        org_input_prompt=input_prompt
        input_prompt = input_prompt_save
        print("gen_count=", gen_count,"min_len=", min_len ,"clr_len=",clr_len)
        for i in range(gen_count ):
                print("ターン=",i)
                #input_promptの続きをgenerate()で生成、min_len x 3文字程度
                gen_out = generate(input_prompt,   temperature_val,  min_len)
                out_txt=gen_out.replace(input_prompt,"")       #gen_outからこのターンのプロンプトを削除(生成分の書き出しが完全に一致していることが条件)
                sum_txt=sum_txt+out_txt                                             #out_txtには新たに生成された文が残っているので、以前までの文に追加
                #生成された文からターン毎削除行数で指定した行を削除
                gen_txt=gen_out.replace(org_input_prompt,"")#生成された文から、入力したプロンプトのみ削除
                print("gen_txt     =",gen_txt)
                out_form=gen_txt.replace("。","。\n").replace("」「","」\n「")
                new_form=out_form.split("\n")[clr_len:]
                #新たに入力したプロンプトと削除したあとの文を結合して次のターンのプロンプトにする
                input_prompt=org_input_prompt + "".join(new_form)
                input_prompt_save=input_prompt  #input_prompt_saveは続きを読むに利用する
        print(">>>> sum_txt =",sum_txt )
        return  sum_txt , input_prompt_save

def generate(prompt,   temperature_val,  min_len):
        input_ids = tokenizer.encode(
                prompt,
                add_special_tokens=False,
                return_tensors="pt"
                ).cuda()
        tokens = model.generate(
                input_ids.to(device=model.device),
                max_new_tokens=min_len,
                min_length= min_len,
                temperature=temperature_val,
                top_p=0.9,
                repetition_penalty=1.2,
                do_sample=True,
                pad_token_id=tokenizer.pad_token_id,
                bos_token_id=tokenizer.bos_token_id,
                eos_token_id=tokenizer.eos_token_id )
        out = tokenizer.decode(tokens[0], skip_special_tokens=True)
        return out

def disp_all_txt():
        global  all_txt
        txt_out=all_txt.replace("。","。\n").replace("」「","」\n「") #表示か見やすいように、。や 」「 には改行を挟む
        return txt_out

def txt2file():
        global  all_txt
        txt_out=all_txt.replace("。","。\n").replace("」「","」\n「") #表示か見やすいように、。や 」「 には改行を挟む
        file_name = "gen_novel.txt"
        with open( file_name , "w") as file:
                file.write(txt_out)
        return file_name

# GradioのUIを定義します
with gr.Blocks() as webui:
      with gr.Row():
          with gr.Column():
              sys_prompt    = gr.Textbox(label="あらすじ/書き出し",lines=10, placeholder=" あらすじや、小説の書き出しを入力してください")
              with gr.Row():
                      gen_count       = gr.Number(5, label="生成ターン数")
                      clr_len                = gr.Number(3, label="ターン毎削除行数")
                      temperature  = gr.Number(0.7, label="temperature")
                      min_length     = gr.Number(100, label="min_length(x3文字程度")
              with gr.Row():
                  prompt_inpu    = gr.Button("新規作成",variant="primary")
                  continue_next = gr.Button("続きを読む")
                  disp_all = gr.Button("全文を表示")
          with gr.Column():
              out_data=[gr.Textbox(label="生成小説")]
              save_txt = gr.Button("テキストファイル作成")
              save_txt.click(txt2file)
      # 各ボタンがクリックされたときの処理を設定します
      prompt_inpu.click(pre_generate, inputs=[sys_prompt,  gen_count, temperature,  min_length ,clr_len], outputs=out_data )
      continue_next.click(continue_txt, inputs=[sys_prompt, gen_count, temperature,  min_length ,clr_len], outputs=out_data)
      disp_all.click(disp_all_txt , outputs=out_data)
# Gradioアプリケーションを起動します
webui.launch()

エンドレスの小説を書くことができます。終わらす方法はわかりませんね。