gradio 入門 (2) - Interface
「gradio」のInterfaceの使い方をまとめました。
前回
1. Interfaceの状態
gradio のInterfaceの状態には、「グローバル状態」と「セッション状態」があります。
1-2. グローバル状態
「グローバル状態」は、関数呼び出しの外で変数を作成し、関数内でその変数にアクセスします。たとえば、大きなモデルを関数外でロードし、それを関数内で使用すると、すべての関数呼び出しでモデルを再ロードする必要がなくなります。
import gradio as gr
# グローバル状態
scores = []
# 上位3つの値を保持する関数
def track_score(score):
scores.append(score)
top_scores = sorted(scores, reverse=True)[:3]
return top_scores
# Interfaceの作成
demo = gr.Interface(
track_score,
gr.Number(label="Score"),
gr.JSON(label="Top Scores")
)
# 起動
demo.launch()
上記コードでは、スコア配列はすべてのユーザー間で共有されます。 複数のユーザーがこのデモにアクセスしている場合、それらのユーザーのスコアはすべて同じリストに追加されます。
1-2. セッション状態
「セッション状態」は、ページセッション内でデータが永続化されます。ただし、異なるユーザー間ではデータが共有されません。
「セッション状態」の使用手順は、次のとおりです。
チャットボットは、「セッション状態」が必要な例です。ユーザーの会話履歴を「グローバル状態」に保存することはできません。異なるユーザー間で会話履歴がまざってしまうためです。
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# トークナイザーとモデルの準備
tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
# user関数 (history:セッション状態)
def user(message, history):
return "", history + [[message, None]]
# bot関数 (history:セッション状態)
def bot(history):
user_message = history[-1][0]
new_user_input_ids = tokenizer.encode(
user_message + tokenizer.eos_token, return_tensors="pt"
)
# レスポンスの生成
bot_input_ids = torch.cat([torch.LongTensor([]), new_user_input_ids], dim=-1)
response = model.generate(
bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id
).tolist()
# トークンをテキストに変換し、応答を行に分割
response = tokenizer.decode(response[0]).split("<|endoftext|>")
response = [
(response[i], response[i + 1]) for i in range(0, len(response) - 1, 2)
]
history[-1] = response[0]
return history
# Blocksの作成
with gr.Blocks() as demo:
# UI
chatbot = gr.Chatbot()
msg = gr.Textbox()
clear = gr.Button("Clear")
# イベントリスナー
msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
bot, chatbot, chatbot
)
clear.click(lambda: None, None, chatbot, queue=False)
# 起動
demo.launch()
submit後に状態 history がどのように維持されるかに注目してください。user関数で入力 [msg, chatbot]・出力 [msg, chatbot] を処理した後、bot関数で入力 chatbot ・出力 chatbot を処理しています。
2. Reactiveインターフェイス
「Reactiveインターフェイス」は、データを自動的に更新または継続的にストリーミングするインターフェイスです。
2-1. live=True
「live=True」を指定すると、ユーザー入力が変更されるとすぐ、Interfaceが自動的に更新されます。submitボタンはなくなります
import gradio as gr
# calculator関数
def calculator(num1, operation, num2):
if operation == "add":
return num1 + num2
elif operation == "subtract":
return num1 - num2
elif operation == "multiply":
return num1 * num2
elif operation == "divide":
return num1 / num2
# Interfaceの作成
demo = gr.Interface(
calculator,
[
"number",
gr.Radio(["add", "subtract", "multiply", "divide"]),
"number"
],
"number",
live=True,
)
# 起動
demo.launch()
2-2. streaming=True
マイクモードのAudio や WebカメラモードのImage で「streaming=True」を指定すると、データが継続的にバックエンドに送信され、Interfaceが自動的に更新されます。
import gradio as gr
import numpy as np
# 画像反転の関数
def flip(im):
return np.flipud(im)
# Interfaceの作成
demo = gr.Interface(
flip,
gr.Image(source="webcam", streaming=True),
"image",
live=True
)
# 起動
demo.launch()
3. 4種類のインターフェース
gradioには次の4種類のインタフェースがあります。
3-1. 標準デモ
inputsとoutputsに別のコンポーネントを指定します。
import gradio as gr
import numpy as np
# セピア化の関数
def sepia(input_img):
sepia_filter = np.array([
[0.393, 0.769, 0.189],
[0.349, 0.686, 0.168],
[0.272, 0.534, 0.131]
])
sepia_img = input_img.dot(sepia_filter.T)
sepia_img /= sepia_img.max()
return sepia_img
# Interfaceの作成
demo = gr.Interface(
sepia,
gr.Image(shape=(200, 200)),
"image"
)
# 起動
demo.launch()
3-2. 出力専用デモ
inputsにNoneを指定します。
import time
import gradio as gr
# 関数
def fake_gan():
time.sleep(1)
images = [
"https://images.unsplash.com/photo-1507003211169-0a1dd7228f2d?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=387&q=80",
"https://images.unsplash.com/photo-1554151228-14d9def656e4?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=386&q=80",
"https://images.unsplash.com/photo-1542909168-82c3e7fdca5c?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxzZWFyY2h8MXx8aHVtYW4lMjBmYWNlfGVufDB8fDB8fA%3D%3D&w=1000&q=80",
]
return images
# Interfaceの作成
demo = gr.Interface(
fn=fake_gan,
inputs=None,
outputs=gr.Gallery(label="Generated Images").style(grid=[2]),
title="FD-GAN",
description="This is a fake demo of a GAN. In reality, the images are randomly chosen from Unsplash.",
)
# 起動
demo.launch()
3-3. 入力専用デモ
outputsにNoneを指定します。
import random
import string
import gradio as gr
# 画像を保存
def save_image_random_name(image):
random_string = ''.join(random.choices(string.ascii_letters, k=20)) + '.png'
image.save(random_string)
print(f"Saved image to {random_string}!")
# Interfaceの作成
demo = gr.Interface(
fn=save_image_random_name,
inputs=gr.Image(type="pil"),
outputs=None,
)
# 起動
demo.launch()
3-4. 統合デモ
inputsとoutputsに同じコンポーネントを指定します。
import gradio as gr
from transformers import pipeline
# パイプラインの準備
generator = pipeline('text-generation', model = 'gpt2')
# テキスト生成の関数
def generate_text(text_prompt):
response = generator(text_prompt, max_length = 30, num_return_sequences=5)
return response[0]['generated_text']
# テキストボックスの準備
textbox = gr.Textbox()
# Interfaceの作成
demo = gr.Interface(
generate_text,
textbox,
textbox
)
# 起動
demo.launch()