
Google Colab で CodeGemma を試す
「Google Colab」で「CodeGemma」を試したので、まとめました。
1. CodeGemma
「CodeGemma」は、コードタスク用のモデルです。次の3種類のモデルが提供されています。
・google/codegemma-2b : 高速コード補完用
・google/codegemma-7b : コード補完とコード生成用
・google/codegemma-7b-it : コード生成とチャットと指示用

2. コード補完
Colabでのコード補完の手順は、次のとおりです。
(1) パッケージのインストール。
# パッケージのインストール
!pip install transformers accelerate
(2) 「HuggingFace」からAPIキーを取得し、Colabのシークレットマネージャーの「HF_TOKEN」に登録。

(3) トークナイザーとモデルの準備。
今回は、「google/codegemma-2b」を使います。
from transformers import GemmaTokenizer, AutoModelForCausalLM
# トークナイザーとモデルの準備
tokenizer = GemmaTokenizer.from_pretrained(
"google/codegemma-2b"
)
model = AutoModelForCausalLM.from_pretrained(
"google/codegemma-2b",
device_map="auto",
torch_dtype="auto",
)
(4) 推論の実行。
# プロンプトの準備
prompt = '''\
<|fim_prefix|>import datetime
def calculate_age(birth_year):
"""Calculates a person's age based on their birth year."""
current_year = datetime.date.today().year
<|fim_suffix|>
return age<|fim_middle|>\
'''
# 推論の実行
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
prompt_len = inputs["input_ids"].shape[-1]
outputs = model.generate(**inputs, max_new_tokens=100)
print(tokenizer.decode(outputs[0][prompt_len:]))
age = current_year - birth_year<|file_separator|><eos>
プロンプトで使用するスペシャルトークンは、次のとおりです。
・<|fim_prefix|> : 補完前のコンテキストの先頭に配置
・<|fim_suffix|> : サフィックスの前に配置。コード補完の生成場所
・<|fim_middle|> : モデルに生成を促す場所に配置
・<|file_separator|> : 複数ファイルのセパレータ
3. コード生成
(1) 推論の実行。
# プロンプトの準備
# n番目のフィボナッチ数を計算するPython関数を書いてください。
prompt = "Write me a Python function to calculate the nth fibonacci number."
# 推論の実行
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
prompt_len = inputs["input_ids"].shape[-1]
outputs = model.generate(**inputs, max_new_tokens=100)
print(tokenizer.decode(outputs[0][prompt_len:]))
py
<|fim_prefix|><|fim_suffix|><|fim_middle|>def fibonacci(n):
▁▁▁▁if n == 0:
▁▁▁▁▁▁▁▁return 0
▁▁▁▁elif n == 1:
▁▁▁▁▁▁▁▁return 1
▁▁▁▁else:
▁▁▁▁▁▁▁▁return fibonacci(n-1) + fibonacci(n-2)
▁▁▁▁
n = int(input("Enter the nth term: "))
print("The nth term of the Fibonacci series is:", fibonacci(n))
<|file_separator|><eos>