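Trying out Adala, an agent that does slick, no-fuss evaluation for you
Hey there (/・ω・)/
I got curious about Adala, which was introduced on the Label Studio blog, so I gave it a spin ☆
The GitHub repo is here (/・ω・)/
It's quicker to just look at the results of actually running it, so let's do it!!
I ran it on Google Colab (/・ω・)/
First, install it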
!pip install adala
This time I'll have it solve a binary classification task ☆
So let's prepare some labeled training data
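import pandas as pd

# Dish names stay in Japanese because the logs below refer to them as-is.
df = pd.DataFrame([
    ["寿司", "Traditional"],          # sushi
    ["ラーメン", "Modern"],           # ramen
    ["天ぷら", "Traditional"],        # tempura
    ["ハンバーガー", "Modern"],       # hamburger
    ["刺身", "Traditional"],          # sashimi
    ["カレーライス", "Modern"],       # curry rice
    ["おでん", "Traditional"],        # oden
    ["パスタ", "Modern"],             # pasta
    ["餃子", "Modern"],               # gyoza
    ["納豆", "Traditional"],          # natto
    ["うなぎの蒲焼", "Traditional"],  # grilled eel (unagi kabayaki)
], columns=["text", "ground_truth"])
df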
Then load it in as a dataset
from adala.datasets import DataFrameDataset
dataset = DataFrameDataset(df=df)
And now the agent setup
from adala.agents import Agent
from adala.environments import BasicEnvironment
from adala.skills import ClassificationSkill
from adala.runtimes import OpenAIRuntime
from rich import print
import os

os.environ["OPENAI_API_KEY"] = "sk-"

agent = Agent(
    # define the agent's labeling skill that should classify text onto 2 categories
    skills=ClassificationSkill(
        name='traditional_or_modern_detection',
        description='Understanding traditional dish and modern dish statements from text.',
        instructions='Classify a dish name as either expressing "Traditional" or "Modern" statements.',
        labels=['Traditional', 'Modern'],
        input_data_field='text'
    ),
    # basic environment extracts ground truth signal from the input records
    environment=BasicEnvironment(
        ground_truth_dataset=dataset,
        ground_truth_column='ground_truth'
    ),
    runtimes={
        # You can specify your OPENAI API KEY here via `OpenAIRuntime(..., api_key='your-api-key')`
        'openai': OpenAIRuntime(model='gpt-3.5-turbo-instruct'),
        'openai-gpt3': OpenAIRuntime(model='gpt-3.5-turbo'),
        'openai-gpt4': OpenAIRuntime(model='gpt-4'),
    },
    default_runtime='openai',
    # NOTE! If you don't have an access to gpt4 - replace it with "openai-gpt3"
    default_teacher_runtime='openai-gpt4'
)
print(agent)
Here's what print shows
Agent Instance
Environment: BasicEnvironment
Skills: traditional_or_modern_detection
Runtimes: openai, openai-gpt3, openai-gpt4
Default Runtime: openai
Default Teacher Runtime: openai-gpt4
Now go learn (/・ω・)/
learning_experience = agent.learn(learning_iterations=3, accuracy_threshold=0.95)
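And here's the key part: it repeats evaluation, analysis, and improvement for up to the number of iterations given in the learning_iterations option (/・ω・)/
Evaluation (only five rows are displayed, but it's looking at all 11 training records ☆)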
100%|██████████| 11/11 [00:02<00:00, 3.93it/s]
=> Iteration #0: Comparing to ground truth, analyzing and improving ...
Comparing predictions to ground truth data ...
text          ground_truth  traditional_or_modern_de…  score                                                           ground_truth__x__traditi…
─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
寿司          Traditional   Traditional                {'Traditional': -0.14485012, 'Modern': -2.003607}               True
ラーメン      Modern        Traditional                {'Traditional': -0.12320776999999997, 'Modern': -2.1548545}     False
天ぷら        Traditional   Traditional                {'Traditional': -0.02719422399999996, 'Modern': -3.6183197}     True
ハンバーガー  Modern        Modern                     {'Traditional': -4.058615, 'Modern': -0.017423895999999977}     True
刺身          Traditional   Traditional                {'Traditional': -0.4935383199999999, 'Modern': -0.9427952}      True
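The values in the score column above look like log-probabilities for each label. That is my assumption; the log itself doesn't say so. If that reading is right, exponentiating them should give per-label probabilities that add up to roughly 1. A minimal sketch using two rows from the table:

```python
import math

# Scores copied from the Iteration #0 table above.
# Assumption: these are natural-log probabilities, one per label.
scores = {
    "寿司": {"Traditional": -0.14485012, "Modern": -2.003607},
    "刺身": {"Traditional": -0.4935383199999999, "Modern": -0.9427952},
}


def to_probabilities(logprobs):
    """Convert a {label: log-probability} dict into {label: probability}."""
    return {label: math.exp(lp) for label, lp in logprobs.items()}


probabilities = {dish: to_probabilities(lp) for dish, lp in scores.items()}

for dish, probs in probabilities.items():
    best = max(probs, key=probs.get)  # label with the highest probability
    print(dish, best, {label: round(p, 3) for label, p in probs.items()})
```

Read this way, 刺身 was a noticeably less confident "Traditional" call than 寿司, even though both were marked correct.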
Analysis
Analyze evaluation experience ...
100%|██████████| 3/3 [00:00<00:00, 157.31it/s]
100%|██████████| 3/3 [00:15<00:00, 5.02s/it]
Number of errors: 3
Accuracy = 72.73%
Improve "traditional_or_modern_detection" skill based on analysis ...
Updated instructions for skill "traditional_or_modern_detection":
Based on the dish name provided, determine whether it represents a "Traditional" or "Modern" culinary style.
Examples:
Input: ラーメン
Output: Modern
Input: 餃子
Output: Modern
Input: カレーライス
Output: Modern
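The accuracy figure in the log is consistent with 3 errors out of 11 records. A quick sanity check, assuming accuracy is simply correct / total (the log doesn't spell out the formula):

```python
total_records = 11     # rows in the training DataFrame
number_of_errors = 3   # "Number of errors: 3" in the log

# Assumed formula: share of records predicted correctly.
accuracy = (total_records - number_of_errors) / total_records
accuracy_threshold = 0.95  # value passed to agent.learn()

print(f"Accuracy = {accuracy:.2%}")
print("threshold reached:", accuracy >= accuracy_threshold)
```

Since 8/11 is below the 0.95 threshold, the run moves on to another iteration instead of stopping here.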
Applying the improved instructions
Re-apply traditional_or_modern_detection skill to dataset ...
100%|██████████| 11/11 [00:03<00:00, 3.66it/s]
Take 2
=> Iteration #1: Comparing to ground truth, analyzing and improving ...
Comparing predictions to ground truth data ...
text          ground_truth  traditional_or_modern_de…  score                                                           ground_truth__x__traditi…
─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
寿司          Traditional   Traditional                {'Traditional': -0.00504399839999998, 'Modern': -5.292077}      True
ラーメン      Modern        Modern                     {'Traditional': -4.4393187, 'Modern': -0.011874314000000028}    True
天ぷら        Traditional   Traditional                {'Traditional': -0.009848873999999952, 'Modern': -4.62532}      True
ハンバーガー  Modern        Modern                     {'Traditional': -5.9589763, 'Modern': -0.0025858853000000357}   True
刺身          Traditional   Traditional                {'Traditional': -0.0013666658000000445, 'Modern': -6.5960474}   True
Analyze evaluation experience ...
Number of errors: 0
Accuracy = 100.00%
Accuracy threshold reached (1.0 >= 0.95)
Train is done!
Accuracy hit 100% here, so training ended!!
By the way, inference uses whatever you set as the runtime in the options, but analysis and improvement are handled by the teacher runtime ☆
The default is GPT-4 ☆
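Putting the log together, the learn loop appears to behave roughly like the sketch below: evaluate, stop if the threshold is met, otherwise analyze the errors and rewrite the instructions, up to learning_iterations times. This is only my reading of the output, not Adala's actual implementation. The functions learn, fake_evaluate, and fake_improve are hypothetical stand-ins (evaluate for the regular runtime, improve for the teacher runtime), and the list of misclassified dishes is a guess based on the examples in the updated instructions.

```python
def learn(evaluate, improve, instructions, learning_iterations=3, accuracy_threshold=0.95):
    """Hypothetical evaluate -> analyze -> improve loop (not Adala's real code)."""
    history = []
    for iteration in range(learning_iterations):
        accuracy, errors = evaluate(instructions)
        history.append((iteration, accuracy))
        if accuracy >= accuracy_threshold:
            break  # good enough, stop early
        instructions = improve(instructions, errors)
    return instructions, history


def fake_evaluate(instructions):
    """Toy stand-in for the regular runtime: perfect once examples are present."""
    if "Examples:" in instructions:
        return 1.0, []
    # Guessed error list, taken from the examples in the updated instructions.
    return 8 / 11, ["ラーメン", "餃子", "カレーライス"]


def fake_improve(instructions, errors):
    """Toy stand-in for the teacher runtime: append the errors as examples."""
    examples = "\n".join(f"Input: {dish}\nOutput: Modern" for dish in errors)
    return f"{instructions}\nExamples:\n{examples}"


final_instructions, history = learn(fake_evaluate, fake_improve, "Classify a dish name.")
print(history)
print(final_instructions)
```

With these toy stand-ins the loop runs twice and stops, which mirrors the Iteration #0 and Iteration #1 pattern in the log.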
And you can look at the results too (/・ω・)/
learning_experience.predictions
Shall we test it, then?
test_df = pd.DataFrame([
    "カリフォルニアロール",  # California roll
    "うどん",                # udon
    "アイスクリーム",        # ice cream
    "おにぎり"               # onigiri (rice ball)
], columns=['text'])
test_df
Here we go (/・ω・)/
result = agent.apply_skills(test_df)
result.predictions
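Results
This time I only tweaked the quickstart a bit and gave it a quick try, so the ground truth was set up front. But if you implement request_feedback as needed, I think you could pull in results from another data source (Label Studio, for example) and run with those (/・ω・)/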
@abstractmethod
def request_feedback(self, skill: BaseSkill, experience: ShortTermMemory):
    """Request user feedback using predictions and update internal ground truth set."""
By the way, these seem to be the skills available right now (/・ω・)/
👉 Available skills
ClassificationSkill – Classify text into a set of predefined labels.
ClassificationSkillWithCoT – Classify text into a set of predefined labels, using Chain-of-Thoughts reasoning.
SummarizationSkill – Summarize text into a shorter text.
QuestionAnsweringSkill – Answer questions based on a given context.
TranslationSkill – Translate text from one language to another.
TextGenerationSkill – Generate text based on a given prompt.
It's still brand new, so I'm looking forward to what comes next (/・ω・)/
And that's a wrap.