見出し画像

gradio 入門 (2) - Interface

「gradio」のInterfaceの使い方をまとめました。

前回

1. Interfaceの状態

gradio のInterfaceの状態には、「グローバル状態」と「セッション状態」があります。

1-2. グローバル状態

「グローバル状態」は、関数呼び出しの外で変数を作成し、関数内でその変数にアクセスします。たとえば、大きなモデルを関数外でロードし、それを関数内で使用すると、すべての関数呼び出しでモデルを再ロードする必要がなくなります。

import gradio as gr

# グローバル状態
scores = []

# 上位3つの値を保持する関数
def track_score(score):
    scores.append(score)
    top_scores = sorted(scores, reverse=True)[:3]
    return top_scores

# Interfaceの作成
demo = gr.Interface(
    track_score,
    gr.Number(label="Score"),
    gr.JSON(label="Top Scores")
)

# 起動
demo.launch()

上記コードでは、スコア配列はすべてのユーザー間で共有されます。 複数のユーザーがこのデモにアクセスしている場合、それらのユーザーのスコアはすべて同じリストに追加されます。

1-2. セッション状態

「セッション状態」は、ページセッション内でデータが永続化されます。ただし、異なるユーザー間ではデータが共有されません。

「セッション状態」の使用手順は、次のとおりです。

(1) 状態を表す追加のパラメータを関数に渡す。
(2) 更新された状態を追加の戻り値を関数から返す。
(3) Interface作成時に状態入力コンポーネントと状態出力コンポーネントを追加。

チャットボットは、「セッション状態」が必要な例です。ユーザーの会話履歴を「グローバル状態」に保存することはできません。異なるユーザー間で会話履歴がまざってしまうためです。

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# トークナイザーとモデルの準備
tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")

# user関数 (history:セッション状態)
def user(message, history):
    return "", history + [[message, None]]

# bot関数 (history:セッション状態)
def bot(history):
    user_message = history[-1][0]
    new_user_input_ids = tokenizer.encode(
        user_message + tokenizer.eos_token, return_tensors="pt"
    )

    # レスポンスの生成
    bot_input_ids = torch.cat([torch.LongTensor([]), new_user_input_ids], dim=-1)
    response = model.generate(
        bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id
    ).tolist()

    # トークンをテキストに変換し、応答を行に分割
    response = tokenizer.decode(response[0]).split("<|endoftext|>")
    response = [
        (response[i], response[i + 1]) for i in range(0, len(response) - 1, 2)
    ]
    history[-1] = response[0]
    return history

# Blocksの作成
with gr.Blocks() as demo:
    # UI
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    clear = gr.Button("Clear")

    # イベントリスナー
    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, chatbot, chatbot
    )
    clear.click(lambda: None, None, chatbot, queue=False)

# 起動
demo.launch()

submit後に状態 history がどのように維持されるかに注目してください。user関数で入力 [msg, chatbot]・出力 [msg, chatbot] を処理した後、bot関数で入力 chatbot ・出力 chatbot を処理しています。

2. Reactiveインターフェイス

「Reactiveインターフェイス」は、データを自動的に更新または継続的にストリーミングするインターフェイスです。

2-1. live=True

live=True」を指定すると、ユーザー入力が変更されるとすぐ、Interfaceが自動的に更新されます。submitボタンはなくなります

import gradio as gr

# calculator関数
def calculator(num1, operation, num2):
    if operation == "add":
        return num1 + num2
    elif operation == "subtract":
        return num1 - num2
    elif operation == "multiply":
        return num1 * num2
    elif operation == "divide":
        return num1 / num2

# Interfaceの作成
demo = gr.Interface(
    calculator,
    [
        "number",
        gr.Radio(["add", "subtract", "multiply", "divide"]),
        "number"
    ],
    "number",
    live=True,
)

# 起動
demo.launch()

2-2. streaming=True

マイクモードのAudio や WebカメラモードのImage で「streaming=True」を指定すると、データが継続的にバックエンドに送信され、Interfaceが自動的に更新されます。

import gradio as gr
import numpy as np

# 画像反転の関数
def flip(im):
    return np.flipud(im)

# Interfaceの作成
demo = gr.Interface(
    flip,
    gr.Image(source="webcam", streaming=True),
    "image",
    live=True
)

# 起動
demo.launch()

3. 4種類のインターフェース

gradioには次の4種類のインタフェースがあります。

・標準デモ : 入力コンポーネントと出力コンポーネントがある
・出力専用デモ : 出力コンポーネントのみ
・入力専用デモ : 入力コンポーネントのみ
・統合デモ : 入出力コンポーネントのみ

3-1. 標準デモ

inputsとoutputsに別のコンポーネントを指定します。

import gradio as gr
import numpy as np


# セピア化の関数
def sepia(input_img):
    sepia_filter = np.array([
        [0.393, 0.769, 0.189],
        [0.349, 0.686, 0.168],
        [0.272, 0.534, 0.131]
    ])
    sepia_img = input_img.dot(sepia_filter.T)
    sepia_img /= sepia_img.max()
    return sepia_img

# Interfaceの作成
demo = gr.Interface(
    sepia,
    gr.Image(shape=(200, 200)),
    "image"
)

# 起動
demo.launch()

3-2. 出力専用デモ

inputsにNoneを指定します。

import time
import gradio as gr

# 関数
def fake_gan():
    time.sleep(1)
    images = [
            "https://images.unsplash.com/photo-1507003211169-0a1dd7228f2d?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=387&q=80",
            "https://images.unsplash.com/photo-1554151228-14d9def656e4?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=386&q=80",
            "https://images.unsplash.com/photo-1542909168-82c3e7fdca5c?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxzZWFyY2h8MXx8aHVtYW4lMjBmYWNlfGVufDB8fDB8fA%3D%3D&w=1000&q=80",
    ]
    return images

# Interfaceの作成
demo = gr.Interface(
    fn=fake_gan,
    inputs=None,
    outputs=gr.Gallery(label="Generated Images").style(grid=[2]),
    title="FD-GAN",
    description="This is a fake demo of a GAN. In reality, the images are randomly chosen from Unsplash.",
)

# 起動
demo.launch()

3-3. 入力専用デモ

outputsにNoneを指定します。

import random
import string
import gradio as gr

# 画像を保存
def save_image_random_name(image):
    random_string = ''.join(random.choices(string.ascii_letters, k=20)) + '.png'
    image.save(random_string)
    print(f"Saved image to {random_string}!")

# Interfaceの作成
demo = gr.Interface(
    fn=save_image_random_name,
    inputs=gr.Image(type="pil"),
    outputs=None,
)

# 起動
demo.launch()

3-4. 統合デモ

inputsとoutputsに同じコンポーネントを指定します。

import gradio as gr
from transformers import pipeline

# パイプラインの準備
generator = pipeline('text-generation', model = 'gpt2')

# テキスト生成の関数
def generate_text(text_prompt):
    response = generator(text_prompt, max_length = 30, num_return_sequences=5)
    return response[0]['generated_text']

# テキストボックスの準備
textbox = gr.Textbox()

# Interfaceの作成
demo = gr.Interface(
    generate_text,
    textbox,
    textbox
)

# 起動
demo.launch()

次回



いいなと思ったら応援しよう!