LangChain with RWKV-5-World-3B?

I wanted to find out whether a LangChain Agent works with RWKV-5-World-3B, so I wrote a custom LLM wrapper.

from typing import Any, Dict, List, Mapping, Optional, Set
from huggingface_hub import hf_hub_download
from rwkv.model import RWKV
from rwkv.utils import PIPELINE, PIPELINE_ARGS
from langchain.llms.base import LLM
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.utils import enforce_stop_tokens
from langchain.pydantic_v1 import BaseModel, Extra, root_validator


class RWKV_5_WORLD(LLM, BaseModel):
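    """A custom LangChain LLM wrapper around RWKV-5-World (pydantic v1 style)."""
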
    # RWKV params ----------------------------
    model_name = "RWKV-5-World-3B-v2-20231118-ctx16k"
    strategy = "cuda fp16"
    tokenizer = "rwkv_vocab_v20230424"
    top_p = 0.3
    top_k = 100  # top_k = 0 then ignore
    alpha_frequency = 0.25
    alpha_presence = 0.25
    alpha_decay = 0.996  # gradually decay the penalty
    token_ban = []  # ban the generation of some tokens
    token_stop = []  # stop generation whenever you see any token here
    chunk_len = 256  # split input into chunks to save VRAM
    token_count = 200  # Maximum number of tokens to generate.
    temperature = 1.0
    # ----------------------------------------
    client: Any = None  #: :meta private:
    pipeline: Any = None  #: :meta private:

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return {
            "verbose": self.verbose,
            "model_name": self.model_name,
            "strategy": self.strategy,
            "tokenizer": self.tokenizer,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "alpha_frequency": self.alpha_frequency,
            "alpha_presence": self.alpha_presence,
            "alpha_decay": self.alpha_decay,
            "token_ban": self.token_ban,
            "token_stop": self.token_stop,
            "chunk_len": self.chunk_len,
            "token_count": self.token_count,
            "temperature": self.temperature,
        }

    @staticmethod
    def _rwkv_param_names() -> Set[str]:
        """Parameter names handled by the base LLM rather than RWKV (unused here)."""
        return {
            "verbose",
        }

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Download the model weights and build the RWKV client and pipeline."""
        # rwkv is already imported at module level (pip install rwkv),
        # so no ImportError handling is needed here.
        model_name = values["model_name"]
        model_path = hf_hub_download(
            repo_id="BlinkDL/rwkv-5-world", filename=f"{model_name}.pth"
        )
        if model_path is None:
            raise ValueError(
                "Model path could not be downloaded. Check the model_name."
            )
        values["client"] = RWKV(model=model_path, strategy=values["strategy"])
        values["pipeline"] = PIPELINE(values["client"], values["tokenizer"])
        return values

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {
            **self._default_params,
        }

    @property
    def _llm_type(self) -> str:
        """Return the type of llm."""
        return "rwkv"

    def rwkv_generate(self, prompt: str) -> str:
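        """Generate text with the rwkv pipeline using the sampling parameters above."""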
        args = PIPELINE_ARGS(
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=self.top_k,
            alpha_frequency=self.alpha_frequency,
            alpha_presence=self.alpha_presence,
            alpha_decay=self.alpha_decay,
            token_ban=self.token_ban,
            token_stop=self.token_stop,  # stop token IDs, not the LangChain stop strings
            chunk_len=self.chunk_len,
        )

        result = self.pipeline.generate(prompt, token_count=self.token_count, args=args)
        return result

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        r"""RWKV generation

        Args:
            prompt: The prompt to pass into the model.
            stop: A list of strings to stop generation when encountered.

        Returns:
            The string generated by the model.

        Example:
            .. code-block:: python

                prompt = "Once upon a time, "
                response = model(prompt)
        """
        text = self.rwkv_generate(prompt)

        if stop is not None:
            text = enforce_stop_tokens(text, stop)
        return text

I saved the class above as rwkv5.py and tried calling it, using this article as a reference:

import rwkv5


# Build an instruction-style prompt
def generate_prompt(instruction, input=None):
    if input:
        return f"""Instruction: {instruction}

Input: {input}

Response: """
    else:
        return f"""Question: {instruction}

Answer: """


llm = rwkv5.RWKV_5_WORLD()

prompt = generate_prompt(
    "後藤ひとりが加入するバンドの名前は何ですか?",
    """
    後藤ひとりはギターを愛する孤独な少女。
    家では孤独でただ遊んでばかりの毎日だったが、ひょんなことから伊地知虹夏率いる「結束バンド」に加入することに。
    人前で演奏することに不慣れな後藤は、立派なバンドマンになれるのか?
    """,
)

result = llm(prompt)
print(result)


後藤ひとりが加入するバンドの名前は「結束バンド」です。
("The name of the band that Hitori Gotoh joins is 'Kessoku Band'.")

Next, let's call it as an Agent.

import rwkv5
from langchain.tools import DuckDuckGoSearchRun
from langchain.agents import Tool, initialize_agent
from langchain.chains import LLMMathChain
from langchain.pydantic_v1 import BaseModel, Field

llm = rwkv5.RWKV_5_WORLD()

# pip install duckduckgo-search
search = DuckDuckGoSearchRun()
tools = [
    Tool(
        name="duckduckgo-search",
        func=search.run,
        description="useful for when you need to answer questions. You should ask targeted questions",
    )
]


class CalculatorInput(BaseModel):
    question: str = Field()


# pip install numexpr
llm_math_chain = LLMMathChain.from_llm(llm=llm, verbose=True)
tools.append(
    Tool.from_function(
        func=llm_math_chain.run,
        name="Calculator",
        description="useful for when you need to answer questions about math",
        args_schema=CalculatorInput,
        # coroutine= ... <- you can specify an async method if desired as well
    )
)

agent = initialize_agent(
    tools,
    llm,
    agent="zero-shot-react-description",
    verbose=True,
    handle_parsing_errors=True,
)

text = "What is the second highest mountain in Japan?"
output = agent.run(text)
output = output.split("\n")[0]
print(output)

text = "Who is Leo DiCaprio's girlfriend? What is her current age?"
output = agent.run(text)
output = output.split("\n")[0]
print(output)

text = "Who is Leo DiCaprio's girlfriend? What is her current age raised to the 0.43 power?"
output = agent.run(text)
output = output.split("\n")[0]
print(output)
It does not work as an Agent


Used for chat, it produces responses that are impressively good for a 3B model, but unfortunately it does not seem to work as an Agent; a model this small likely cannot reliably emit the ReAct Thought/Action/Action Input format that zero-shot-react-description expects, which is presumably why handle_parsing_errors=True was needed in the first place.
I would like to explore other ways of drawing out its capabilities besides Agents.
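
For example, here is a minimal sketch of driving the same wrapper through a plain LLMChain, reusing the Instruction/Input/Response layout that generate_prompt builds above; the instruction and input strings below are only illustrative placeholders, not something from my experiments.

import rwkv5
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate

llm = rwkv5.RWKV_5_WORLD()

# Same Instruction/Input/Response layout as generate_prompt above.
template = """Instruction: {instruction}

Input: {input}

Response: """

chain = LLMChain(
    llm=llm,
    prompt=PromptTemplate(
        input_variables=["instruction", "input"], template=template
    ),
)

# Illustrative placeholder task: "Summarize the following text in one sentence."
result = chain.run(
    instruction="次の文章を一文で要約してください。",
    input="後藤ひとりはギターを愛する孤独な少女。ひょんなことから伊地知虹夏率いる「結束バンド」に加入することに。",
)
print(result)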
