novel-gptで小説を書く、gradioで簡単に長文を生成できるよにしたー(2)
昨日の記事で続きを読むが無いことを書きました。しかし、これが無いと、とても不便です。なので、連投になりますが、「続きを読む」ボタンを追加するとともに、「全文を表示」ボタンと「テキストファイルを作成」ボタンも追加して、概ね使えるアプリに近づけました。
新しいGUIです。ボタンが3個追加されています。
追加部分
def pre_generate(input_prompt, gen_count, temperature_val, min_len,clr_len):
global sum_txt
global input_prompt_save
global all_txt
sum_txt=""
input_prompt_save =input_prompt
all_txt=""
out_txt , input_prompt_save =txt_gen_count(input_prompt, gen_count, temperature_val, min_len,clr_len, sum_txt, input_prompt_save)
sum_txt=out_txt
out_txt=input_prompt + sum_txt #入力したプロンプトの後ろに各ターンで生成された結合文を結合して最終的な生成文にする
all_txt= out_txt
txt_out=out_txt.replace("。","。\n").replace("」「","」\n「") #表示か見やすいように、。や 」「 には改行を挟む
print("####新規",txt_out)
return txt_out
def continue_txt(input_prompt, gen_count, temperature_val, min_len,clr_len):
global sum_txt
global input_prompt_save
global all_txt
pre_sum_txt=sum_txt
txt_out, input_prompt_save=txt_gen_count(input_prompt, gen_count, temperature_val, min_len,clr_len, sum_txt,input_prompt_save)
sum_txt = txt_out
all_txt = input_prompt + txt_out #全文
txt_out=txt_out.replace(pre_sum_txt,"")
txt_out=txt_out.replace("。","。\n").replace("」「","」\n「") #表示か見やすいように、。や 」「 には改行を挟む
print("####新規",txt_out)
return txt_out
新規の作成時のパラメータを txt_gen_countに渡す関数と、続きを生成する場合に以前のパラメータを引き継ぐ関数を設けました。これに伴い txt_gen_count内部も少し変更しています。
全文表示、ファイル化
def disp_all_txt():
global all_txt
txt_out=all_txt.replace("。","。\n").replace("」「","」\n「") #表示か見やすいように、。や 」「 には改行を挟む
return txt_out
def txt2file():
global all_txt
txt_out=all_txt.replace("。","。\n").replace("」「","」\n「") #表示か見やすいように、。や 」「 には改行を挟む
file_name = "gen_novel.txt"
with open( file_name , "w") as file:
file.write(txt_out)
return file_name
gradioから呼びだされる簡単な関数です。「テキストファイル作成」ボタンをクリックするとtxt2file()が呼び出され、起動時のディレクトリ内に
gen_novel.tx
が生成されます。改行を追加した読みやすいフォーマットです。全文が書き出されます。
全コード
from transformers import GPTJForCausalLM, AlbertTokenizer
import torch
import gradio as gr
tokenizer = AlbertTokenizer.from_pretrained(
'AIBunCho/japanese-novel-gpt-j-6b',
keep_accents=True,
remove_space=False)
#量子化なし VRAM=13.5G
#model = GPTJForCausalLM.from_pretrained("AIBunCho/japanese-novel-gpt-j-6b",
# torch_dtype=torch.float16,
# low_cpu_mem_usage=True)
#model.half()
#model.eval()
#if torch.cuda.is_available():
# model = model.to("cuda")
#8bit量子化あり VRAM=7.8G@ 200トークン、8.2G@500トークン
model = GPTJForCausalLM.from_pretrained("AIBunCho/japanese-novel-gpt-j-6b",
load_in_8bit=True,
torch_dtype=torch.float16,
device_map='auto')
#プロンプトの例
#ある朝、事件が起こりました。たしか、あれは3年前のことです。その時、天気は荒れて大雨でした。
input_prompt_save =""
sum_txt=""
all_txt=""
def pre_generate(input_prompt, gen_count, temperature_val, min_len,clr_len):
global sum_txt
global input_prompt_save
global all_txt
sum_txt=""
input_prompt_save =input_prompt
all_txt=""
out_txt , input_prompt_save =txt_gen_count(input_prompt, gen_count, temperature_val, min_len,clr_len, sum_txt, input_prompt_save)
sum_txt=out_txt
out_txt=input_prompt + sum_txt #入力したプロンプトの後ろに各ターンで生成された結合文を結合して最終的な生成文にする
all_txt= out_txt
txt_out=out_txt.replace("。","。\n").replace("」「","」\n「") #表示か見やすいように、。や 」「 には改行を挟む
print("####新規",txt_out)
return txt_out
def continue_txt(input_prompt, gen_count, temperature_val, min_len,clr_len):
global sum_txt
global input_prompt_save
global all_txt
pre_sum_txt=sum_txt
txt_out, input_prompt_save=txt_gen_count(input_prompt, gen_count, temperature_val, min_len,clr_len, sum_txt,input_prompt_save)
sum_txt = txt_out
all_txt = input_prompt + txt_out #全文
txt_out=txt_out.replace(pre_sum_txt,"")
txt_out=txt_out.replace("。","。\n").replace("」「","」\n「") #表示か見やすいように、。や 」「 には改行を挟む
print("####新規",txt_out)
return txt_out
def txt_gen_count(input_prompt, gen_count, temperature_val, min_len,clr_len, sum_txt ,input_prompt_save):
gen_count =int(gen_count)
min_len =int(min_len)
clr_len =int(clr_len)
org_input_prompt=input_prompt
input_prompt = input_prompt_save
print("gen_count=", gen_count,"min_len=", min_len ,"clr_len=",clr_len)
for i in range(gen_count ):
print("ターン=",i)
#input_promptの続きをgenerate()で生成、min_len x 3文字程度
gen_out = generate(input_prompt, temperature_val, min_len)
out_txt=gen_out.replace(input_prompt,"") #gen_outからこのターンのプロンプトを削除(生成分の書き出しが完全に一致していることが条件)
sum_txt=sum_txt+out_txt #out_txtには新たに生成された文が残っているので、以前までの文に追加
#生成された文からターン毎削除行数で指定した行を削除
gen_txt=gen_out.replace(org_input_prompt,"")#生成された文から、入力したプロンプトのみ削除
print("gen_txt =",gen_txt)
out_form=gen_txt.replace("。","。\n").replace("」「","」\n「")
new_form=out_form.split("\n")[clr_len:]
#新たに入力したプロンプトと削除したあとの文を結合して次のターンのプロンプトにする
input_prompt=org_input_prompt + "".join(new_form)
input_prompt_save=input_prompt #input_prompt_saveは続きを読むに利用する
print(">>>> sum_txt =",sum_txt )
return sum_txt , input_prompt_save
def generate(prompt, temperature_val, min_len):
input_ids = tokenizer.encode(
prompt,
add_special_tokens=False,
return_tensors="pt"
).cuda()
tokens = model.generate(
input_ids.to(device=model.device),
max_new_tokens=min_len,
min_length= min_len,
temperature=temperature_val,
top_p=0.9,
repetition_penalty=1.2,
do_sample=True,
pad_token_id=tokenizer.pad_token_id,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id )
out = tokenizer.decode(tokens[0], skip_special_tokens=True)
return out
def disp_all_txt():
global all_txt
txt_out=all_txt.replace("。","。\n").replace("」「","」\n「") #表示か見やすいように、。や 」「 には改行を挟む
return txt_out
def txt2file():
global all_txt
txt_out=all_txt.replace("。","。\n").replace("」「","」\n「") #表示か見やすいように、。や 」「 には改行を挟む
file_name = "gen_novel.txt"
with open( file_name , "w") as file:
file.write(txt_out)
return file_name
# GradioのUIを定義します
with gr.Blocks() as webui:
with gr.Row():
with gr.Column():
sys_prompt = gr.Textbox(label="あらすじ/書き出し",lines=10, placeholder=" あらすじや、小説の書き出しを入力してください")
with gr.Row():
gen_count = gr.Number(5, label="生成ターン数")
clr_len = gr.Number(3, label="ターン毎削除行数")
temperature = gr.Number(0.7, label="temperature")
min_length = gr.Number(100, label="min_length(x3文字程度")
with gr.Row():
prompt_inpu = gr.Button("新規作成",variant="primary")
continue_next = gr.Button("続きを読む")
disp_all = gr.Button("全文を表示")
with gr.Column():
out_data=[gr.Textbox(label="生成小説")]
save_txt = gr.Button("テキストファイル作成")
save_txt.click(txt2file)
# 各ボタンがクリックされたときの処理を設定します
prompt_inpu.click(pre_generate, inputs=[sys_prompt, gen_count, temperature, min_length ,clr_len], outputs=out_data )
continue_next.click(continue_txt, inputs=[sys_prompt, gen_count, temperature, min_length ,clr_len], outputs=out_data)
disp_all.click(disp_all_txt , outputs=out_data)
# Gradioアプリケーションを起動します
webui.launch()
エンドレスの小説を書くことができます。終わらす方法はわかりませんね。