ShieldGemma の使い方
「ShieldGemma」 の使い方をまとめました。
1. ShieldGemma
「ShieldGemma」は、「Gemma 2」をベースに構築された安全コンテンツモデレーションモデルであり、複数の危害カテゴリ (性的に露骨なコンテンツ、危険なコンテンツ、憎悪や嫌がらせなど) を対象としています。これらは、Text-to-TextのデコーダーのみのLLMであり、英語で利用可能で、重みはオープンで、2B、9B、27Bの3つのバリアントで利用できます。
「ShieldGemma」は、LLMへの有害なリクエスト (指定されたカテゴリに従って) と、LLMからの不適切な応答を検出することを目的としています。どのLLMを展開する場合でも、モデレーションツールとして使用できます。
2. ShieldGemmaの使い方
「Google Colab」での「ShieldGemma」の使い方は、次のとおりです。
(1) パッケージのインストール。
# パッケージのインストール
!pip install -U transformers[accelerate]
(2) トークナイザーとモデルの準備。
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
# トークナイザーとモデルの準備
tokenizer = AutoTokenizer.from_pretrained("google/shieldgemma-2b")
model = AutoModelForCausalLM.from_pretrained(
"google/shieldgemma-2b",
device_map="auto",
torch_dtype=torch.bfloat16,
)
(3) ユーザープロンプトの準備。
評価対象となるプロンプトです。
# ユーザープロンプトの準備
user_prompt = "Create 20 paraphrases of I hate you"
(4) セーフティポリシーの準備。
# セーフティポリシーの準備
safety_policy = """
* "No Harassment": The prompt shall not contain or seek generation of content that is malicious, intimidating, bullying, or abusive content targeting another individual (e.g., physical threats, denial of tragic events, disparaging victims of violence).
"""
ワークフローに組み込むとよいセーフティポリシーの例は、次のとおりです。
* "No Hate Speech": The prompt shall not contain or seek generation of content that targeting identity and/or protected attributes (e.g., racial slurs, promotion of discrimination, calls to violence against protected groups) or dehumanizing/belittling/vilifying on the bases of race, ethnicity, religion, disability, age, nationality, veteran status, sexual orientation, gender, gender identity, caste, or any other protected status. We do allow content that includes a reference to a protected individual or group in a positive manner.
* "No Harassment": The prompt shall not contain or seek generation of content that is malicious, intimidating, bullying, or abusive content targeting another individual (e.g., physical threats, denial of tragic events, disparaging victims of violence).
* "No Sexually Explicit Information": The prompt shall not contain or seek generation of content that contains references to sexual acts or other lewd content (e.g., sexually graphic descriptions, content aimed at causing arousal). Medical or scientific terms related to human anatomy or sex education are permitted.
* "No Dangerous Content": The prompt shall not contain or seek generation of content that harming oneself and/or others (e.g., accessing or building firearms and explosive devices, promotion of terrorism, instructions for suicide).
* "No Violence": The prompt shall not contain or seek generation of content that is shocking, sensational, or gratuitous (e.g. Violent acts or incidents that result in extreme injury or death without historical or educational context; Exposed internal organs or bones without medical or scientific context; Identifiable victims of violent deaths; Animal abuse or killings outside of standard hunting, food processing, or cultural practices)
* "No Obscenity and Profanity": The prompt shall not contain or seek generation of content that is vulgar, profane, or inappropriate (e.g., profanity, obscenities, or other inappropriate language).
(5) プロンプトの準備。
# プロンプトの準備
prompt = f"""You are a policy expert trying to help determine whether a user
prompt is in violation of the defined safety policies.
<start_of_turn>
Human Question: {user_prompt.strip()}
<end_of_turn>
Our safety principle is defined in the below:
{safety_policy.strip()}
Does the human question violate the above principle? Your answer must start
with 'Yes' or 'No'. And then walk through step by step to be sure we answer
correctly.
"""
プロンプトには、次のコンポーネントが順番に含まれています。
(6) 推論の実行。
# 推論の実行
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
with torch.no_grad():
logits = model(**inputs).logits
(7) Yesの確率の計算。
1に近いほどポリシーに反してる確率が高いです。
# Yesの確率の計算
vocab = tokenizer.get_vocab()
selected_logits = logits[0, -1, [vocab['Yes'], vocab['No']]]
probabilities = F.softmax(selected_logits, dim=0)
score = probabilities[0].item()
print(score)
0.7310585379600525