RWKV-5-World-3BでLangChain?
RWKV-5-World-3BでLangChainのAgentが動作するか知りたかったので、
カスタムモデルを作りました。
from typing import Any, Dict, List, Mapping, Optional, Set
from huggingface_hub import hf_hub_download
from rwkv.model import RWKV
from rwkv.utils import PIPELINE, PIPELINE_ARGS
from langchain.llms.base import LLM
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.utils import enforce_stop_tokens
from langchain.pydantic_v1 import BaseModel, Extra, root_validator
class RWKV_5_WORLD(LLM, BaseModel):
    """LangChain LLM wrapper for the RWKV-5 World checkpoints.

    On validation the checkpoint ``{model_name}.pth`` is fetched from the
    ``BlinkDL/rwkv-5-world`` Hugging Face repository and loaded into an
    ``RWKV`` client; text generation then goes through ``PIPELINE.generate``.
    """

    # RWKV params ----------------------------
    model_name: str = "RWKV-5-World-3B-v2-20231118-ctx16k"
    strategy: str = "cuda fp16"
    tokenizer: str = "rwkv_vocab_v20230424"
    top_p: float = 0.3
    top_k: int = 100  # top_k = 0 then ignore
    alpha_frequency: float = 0.25
    alpha_presence: float = 0.25
    alpha_decay: float = 0.996  # gradually decay the penalty
    token_ban: List[Any] = []  # ban the generation of some tokens
    token_stop: List[Any] = []  # stop generation whenever you see any token here
    chunk_len: int = 256  # split input into chunks to save VRAM
    token_count: int = 200  # Maximum number of tokens to generate.
    temperature: float = 1.0
    # ----------------------------------------

    client: Any = None  #: :meta private:
    pipeline: Any = None  #: :meta private:

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Return the model and sampling settings of this instance as a dict."""
        return {
            "verbose": self.verbose,
            "model_name": self.model_name,
            "strategy": self.strategy,
            "tokenizer": self.tokenizer,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "alpha_frequency": self.alpha_frequency,
            "alpha_presence": self.alpha_presence,
            "alpha_decay": self.alpha_decay,
            "token_ban": self.token_ban,
            "token_stop": self.token_stop,
            "chunk_len": self.chunk_len,
            "token_count": self.token_count,
            "temperature": self.temperature,
        }

    @staticmethod
    def _rwkv_param_names() -> Set[str]:
        """Return a fixed set of parameter names (currently only ``"verbose"``)."""
        return {
            "verbose",
        }

    # skip_on_failure=True: if a field already failed validation, `values`
    # lacks that key and the lookups below would raise a KeyError that hides
    # the real validation error.
    @root_validator(skip_on_failure=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Download the checkpoint and build the RWKV client and pipeline.

        Args:
            values: The validated field values of the model.

        Returns:
            ``values`` with ``client`` and ``pipeline`` filled in.

        Raises:
            ValueError: If no local path was obtained for the checkpoint.
            ImportError: If an import fails while loading the model.
        """
        try:
            model_name = values["model_name"]
            model_path = hf_hub_download(
                repo_id="BlinkDL/rwkv-5-world", filename=f"{model_name}.pth"
            )
            if model_path is None:
                raise ValueError(
                    "Model path could not be downloaded. Check the model_name."
                )

            values["client"] = RWKV(model=model_path, strategy=values["strategy"])
            values["pipeline"] = PIPELINE(values["client"], values["tokenizer"])
        except ImportError as err:
            # Keep the original exception attached so the failing import is visible.
            raise ImportError(
                "Could not import rwkv python package. "
                "Please install it with `pip install rwkv`."
            ) from err
        return values

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {
            **self._default_params,
        }

    @property
    def _llm_type(self) -> str:
        """Return the type of llm."""
        return "rwkv"

    def rwkv_generate(self, prompt: str) -> str:
        """Generate a completion for ``prompt`` with this instance's settings.

        Args:
            prompt: The text passed to the pipeline as context.

        Returns:
            Whatever ``self.pipeline.generate`` returns for the prompt.
        """
        args = PIPELINE_ARGS(
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=self.top_k,
            alpha_frequency=self.alpha_frequency,
            alpha_presence=self.alpha_presence,
            alpha_decay=self.alpha_decay,
            token_ban=self.token_ban,
            # Only the configured `token_stop` field is used here; the string
            # `stop` sequences given to `_call` are applied after generation.
            token_stop=self.token_stop,
            chunk_len=self.chunk_len,
        )
        return self.pipeline.generate(prompt, token_count=self.token_count, args=args)

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        r"""RWKV generation

        Args:
            prompt: The prompt to pass into the model.
            stop: A list of strings to stop generation when encountered.
                The output is cut after generation; ``None`` or an empty
                list leaves the text untouched.
            run_manager: Accepted for interface compatibility; not used.
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            The string generated by the model.

        Example:
            .. code-block:: python

                prompt = "Once upon a time, "
                response = model(prompt, stop=["\n\n"])
        """
        text = self.rwkv_generate(prompt)
        # Truthiness check on purpose: enforce_stop_tokens joins the stop
        # strings into one regex, and with an empty list that pattern is empty,
        # matches at position 0 and would discard the entire output.
        if stop:
            text = enforce_stop_tokens(text, stop)
        return text
こちらの記事を参考に呼び出してみました。
import rwkv5
def generate_prompt(instruction, input=None):
    """Build an instruct-style prompt string.

    A truthy ``input`` selects the Instruction/Input/Response layout;
    a missing or empty ``input`` selects the Question/Answer layout.
    """
    if not input:
        return f"Question: {instruction}\nAnswer: "
    return f"Instruction: {instruction}\nInput: {input}\nResponse: "
# Instantiate the custom LLM from the rwkv5 module.
llm = rwkv5.RWKV_5_WORLD()

# The question goes in as the instruction, the background passage as the input.
question = "後藤ひとりが加入するバンドの名前は何ですか?"
context = """
後藤ひとりはギターを愛する孤独な少女。
家では孤独でただ遊んでばかりの毎日だったが、ひょんなことから伊地知虹夏率いる「結束バンド」に加入することに。
人前で演奏することに不慣れな後藤は、立派なバンドマンになれるのか?
"""
prompt = generate_prompt(question, context)

# Run one completion and print it.
result = llm(prompt)
print(result)
次は、Agentとして呼び出します。
import rwkv5
from transformers import pipeline
from langchain.tools import DuckDuckGoSearchRun
from langchain.agents import Tool, initialize_agent
from langchain.chains import LLMMathChain
from pydantic import BaseModel, Field
# LLM instance used by the math chain and the agent further down.
llm = rwkv5.RWKV_5_WORLD()

# Web search tool; needs `pip install duckduckgo-search`.
search = DuckDuckGoSearchRun()
search_tool = Tool(
    name="duckduckgo-search",
    func=search.run,
    description="useful for when you need to answer questions. You should ask targeted questions",
)

# Tool list handed to the agent; more tools are appended afterwards.
tools = [search_tool]
class CalculatorInput(BaseModel):
    """Argument schema given to the Calculator tool: one string field, ``question``."""

    question: str = Field()


# Math chain driven by the same LLM; needs `pip install numexpr`.
llm_math_chain = LLMMathChain.from_llm(llm=llm, verbose=True)

calculator_tool = Tool.from_function(
    func=llm_math_chain.run,
    name="Calculator",
    description="useful for when you need to answer questions about math",
    args_schema=CalculatorInput,
    # An async implementation can additionally be supplied via `coroutine=`.
)
tools.append(calculator_tool)
# Options for the zero-shot ReAct agent; parsing errors are handled instead of raised.
agent_options = {
    "agent": "zero-shot-react-description",
    "verbose": True,
    "handle_parsing_errors": True,
}
agent = initialize_agent(tools, llm, **agent_options)

# Questions to put to the agent, in order.
questions = (
    "What is the second highest mountain in Japan?",
    "Who is Leo DiCaprio's girlfriend? What is her current age?",
    "Who is Leo DiCaprio's girlfriend? What is her current age raised to the 0.43 power?",
)

for text in questions:
    # Only the first line of the agent's answer is kept and printed.
    output = agent.run(text).split("\n", 1)[0]
    print(output)
Chatとして利用すると3Bモデルとは思えないほど優秀な応答をしますが、残念ながらAgentとしては動作しないようです。
Agent以外にも何か上手な能力の引き出し方を模索したいですね。
この記事が気に入ったらサポートをしてみませんか?