
深層強化学習AIをClusterScriptで動かす
はじめに
やっほ~ さなだよ
Cluster Creator #2 Advent Calendar 2024 15日目!
この記事では、深層強化学習でAIにゲームの仕組みを学んでもらって、ClusterScript上で動かす方法を紹介するよ!
背景知識
具体的な手順の前に、深層強化学習やその前提となる強化学習の仕組みや考え方について説明していきます
強化学習って何?
プログラムを組む際は通常、if文などを用いて遭遇し得る全てのケースに対し指示を与えます。


しかし、プログラムによっては遭遇し得る全てのケースを明示することが困難なこともあります。例えば、画像の中の物体を判別するようなプログラムを作るとき、全ての画像パターンの条件分岐を設定するのは限りなく困難でしょう。ただ、一定以上の量の画像とそれに対する回答を用意すれば、その傾向から未知の画像に対しても答えを予測することはできます。こうした手法はAIの中で「教師あり学習」と呼ばれ、一部のケースに対する指示を学習データとして与え、未知のケースに対して適切な行動を推論することができます。

しかし、教師あり学習にも問題がないわけではありません。学習データは想定される条件を全てではないにせよ満遍なく網羅している必要がありますし、データ量を確保するには相応のコストがかかります。一方で、プログラムが目指すべき特定の条件のみを設定し、各条件における行動はすべて試行錯誤を通してプログラム自身に学習してもらう手法も考案されました。これが「強化学習」です。

強化学習はスコアや勝敗から明確に報酬を設定できるゲーム用Botと相性が良く、深層強化学習によって作られたAlphaGoが2016年に囲碁棋士の王者を破っています。
強化学習の手順
では、限られた条件からどうやって個々のケースにおける行動を学習していくのでしょうか。強化学習のアルゴリズムはいくつか提案されているのですが、ここではQ-learningと呼ばれるアルゴリズムについて解説していきます。
Q-learningでは、全ての状態における全ての行動の価値を記録したQ-tableを作成します。ここで言う「価値」とは、その条件でその行動を選択することで最終的にどれだけ報酬を得る見込みがあるのかを計算したものです。
Q-tableは以下の式で更新します。αは学習速度, γは直接の報酬に対し報酬の見込みにかける割引率を示すハイパーパラメータです。

では、実際に強化学習がどのように進んでいくかを見ていきましょう。あるマス目からゴールマスにたどり着くことを目指すゲームを例に挙げます。マスの位置を状態、進む方向を行動として扱い、各マス目において各行動がどれだけゴールに繋がるかをQ-tableに記録します。更新式のαは0.1、γは0.9として計算します。
まず最初は、Q-tableの値はすべて0であり、状態①において上下左右どの方向に進むことも同じ価値を持っています。この場合進行方向はランダムに決定されますが、今回はゴールに向かう右方向がたまたま選択されたとしましょう。

状態②における行動の価値はすべて0なので、状態①で右に進むことに価値は加えられません。次にQ-tableに基づいて進む方向を決めるのですが、状態②において上下左右いずれの価値も0なので、行動は再びランダムに決定されます。では、また右方向が選択されたとしましょう。

すると、あてずっぽうではあるもののゴールにたどり着きました。報酬を獲得し、状態②で右に進む行動に価値を加算します。では、初期位置に戻り探索を続行しましょう。

さて2周目になりますが、状態①に関しては1周目と同様行動の価値がすべて0です。なので行動はランダムに決定されるのですが、またまた右方向が選択されたとしましょう。

状態②に移ると、報酬に繋がることを示す価値が記録されていました。すると、以前の行動にも価値がありそうなので、状態①で右に進む行動に価値を加算します。そして、状態②で最も高い価値を持つ右方向への移動を選択します。

するとQ-tableの記録通り、ゴールにたどり着きました。1周目と同様、状態②で右に進む行動に価値を加算し、初期位置に戻って探索を続行します。

以上の流れを終了条件(最大行動数など)を満たすまで繰り返し、Q-tableを更新し続けることで、各状態における最適な行動を学習していきます。今回の例では初期のランダム探索で右に進み続けた場合を紹介しましたが、実際は上下左右様々な方向に進み、様々なマス目からゴールに向かう方向を学習していきます。
実際の探索の流れは以下のワールドで見ることができます。障害物やゴールの設定を変えることもできるので、是非色々試してみてください。
深層強化学習って何?
さて、Q-learningでは全ての状態における全ての行動の価値をQ-tableに記録・更新することで学習しました。しかし環境によっては、全ての状態における全ての行動という膨大なパターンを学習するのが困難で、網羅できたとしても記録するデータ量が膨大になってしまいます。例のような単純なゴール探索であっても、位置情報を連続量として扱うだけでもその情報量は肥大化します。
そこで、各状態に基づいて各行動の価値を返すというQ-tableの役割をニューラルネットワークで代用する手法が考案されました。ニューラルネットワークを用いた学習の中でもその層が深いものを深層強化学習と呼びます。
※「深層」という言葉自体に大して深い意味はないっぽい? 実態を示すなら「ニューラルネットワークを用いた強化学習」かな。まぁ「深層」とか「ディープラーニング」って響きが良いんですよね、なんかバズりやすそうだし。
ニューラルネットワークについて
ニューラルネットワークの基本構造は単純で、入力に重みと呼ばれる数値wをかけた上で足し合わせ、そこにバイアスと呼ばれる数値bを加えるだけです。

行列を使って数式っぽく書くとこうなります

x,y座標に対し適切な方向の価値が大きくなるようにwやbを更新すれば、Q-tableと同じように各状態における最適な行動を選択することができます。
誤差関数の設定
では、どのようにw,bを更新していくのでしょうか? そのためにはまず、現在のw,bがどれだけ適切であるかを示す誤差関数を用意します。

この誤差関数では、得た報酬, 現状態における行動の最大価値と、1つ前の状態で取った行動の価値を比較しています。そして、現状態に報酬や大きな価値のある行動があれば、1つ前の状態で取った行動の価値も大きくなるべきです。つまり、この誤差関数が小さくなるようにw,bを更新することで、各状態において適切な価値を返すようなニューラルネットワークを学習できます。
※現状態における行動の価値をQ*と表記しているのは、実際学習を行うとき、現状態と1つ前の状態で得られる価値を計算するネットワークを分けるためです。実際更新されるのはQのみで、Q*は数ステップごとにQからコピーされます。詳細は下記の記事の「Fixed Target Q-Network」を参照
Deep Q-Network(DQN)ことはじめ #強化学習 - Qiita
パラメータの更新手法
誤差関数を小さくするようなw,bの更新は一般的なニューラルネットワークと同じ手順です。良質な解説が既にたくさん存在するのでそれらを紹介しておきます。
大学生のインフラ 予備ノリ
StatQuest
1つ1つの値の更新の仕方を丁寧に説明していて個人的に一番分かりやすかった。英語だけど字幕と図解が豊富だから意外とハードル高くなさそう…?
背景知識 まとめ
強化学習は、報酬を与える条件を設定し、プログラムが報酬を得るための行動を繰り返し探索することで、各状態における行動を自ら学習する手法
学習は、各状態に対する各行動の価値を適切に更新することで行われる。
通常の強化学習では全ての状態に対する全ての行動の価値をQ-tableに記録するが、深層強化学習では状態量に対するニューラルネットワークの応答を各行動の価値として扱う
ニューラルネットワークは入力の重み付き和にバイアスを足して出力する
Unityでの事前学習
さて、ここからは深層強化学習するAIを作成する手順を説明していきます。
ニューラルネットワークの学習には時間がかかるので、Unity上で事前に学習したネットワークを作成し、cluster上では学習済みモデルに基づく行動選択のみを行うようにします。
ml-agent
Unityには、深層強化学習用を行うためのライブラリ「ml-agent」が配布されています。これは、ゲーム環境と行動を学習するエージェントを用意し、報酬を獲得する条件、ネットワークモデルや学習パラメータの設定を行うだけで深層強化学習を実行しモデルを作成してくれます。
ml-agentの詳細なセットアップ方法や使い方については既に優れた解説書があるので割愛して、概要だけを説明していきます。
布留川 英一 様による解説書籍
著者様によるnote簡易な解説記事
ゲーム環境
ml-agentはCCK環境ではなく、通常のUnityのC#プログラムを想定して作成されているため、ゲーム環境とエージェントをC#とCCKの両方で同じものを作成します。
まずはnote記事の例にある、Cubeに触れればクリアのゲームを例に説明します。Sphereが自分の位置・速度・Cubeの位置をもとに前後左右に加速します。そして、床から落ちないように最短でCubeに触れるための動きを学習していきます。
学習環境

ml-agentのスクリプト(C#)
using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Policies;
// RollerAgent
public class RollerDiscrete : Agent
{
public Transform target; // TargetのTransform
Rigidbody rBody; // RollerAgentのRigidBody
// 初期化
public override void Initialize()
{
this.rBody = GetComponent<Rigidbody>();
}
// エピソード開始
public override void OnEpisodeBegin()
{
// RollerAgentが床から落下していたら位置と速度をリセット
if (this.transform.localPosition.y < 0)
{
this.rBody.angularVelocity = Vector3.zero;
this.rBody.velocity = Vector3.zero;
this.transform.localPosition = new Vector3(0.0f, 0.5f, 0.0f);
}
// Targetをランダムにスポーン
target.localPosition = new Vector3(
Random.value*8-4, 0.5f, Random.value*8-4);
}
// 状態取得
public override void CollectObservations(VectorSensor sensor)
{
sensor.AddObservation(target.localPosition.x); //TargetのX座標
sensor.AddObservation(target.localPosition.z); //TargetのZ座標
sensor.AddObservation(this.transform.localPosition.x); //RollerAgentのX座標
sensor.AddObservation(this.transform.localPosition.z); //RollerAgentのZ座標
sensor.AddObservation(rBody.velocity.x); // RollerAgentのX速度
sensor.AddObservation(rBody.velocity.z); // RollerAgentのZ速度
}
// 行動実行
public override void OnActionReceived(ActionBuffers actionBuffers)
{
// RollerAgentに力を加える
Vector3 controlSignal = Vector3.zero;
int action = actionBuffers.DiscreteActions[0];
if(action == 1) controlSignal.z = 1.0f;
if(action == 2) controlSignal.z = -1.0f;
if(action == 3) controlSignal.x = -1.0f;
if(action == 4) controlSignal.x = 1.0f;
rBody.AddForce(controlSignal * 5);
// RollerAgentがTargetの位置にたどりついた時
float distanceToTarget = Vector3.Distance(
this.transform.localPosition, target.localPosition);
if (distanceToTarget < 1.42f)
{
AddReward(1.0f);
EndEpisode();
}
// RollerAgentが床から落下した時
if (this.transform.localPosition.y < 0)
{
EndEpisode();
}
}
}
学習過程
学習を実行すると以下のようになります。最初はランダムな探索をしており、なかなかCubeに当たりませんが、偶然衝突し報酬を得ることを繰り返すうちに、効率良くCubeに向かって動くようになっていきます。
— さな (@ponpopon37) December 15, 2024
学習結果
学習を進めると、モデルを記録したOnnxファイル生成されます。Sphereに登録すると、以下のように学習した動きを見ることができます。
— さな (@ponpopon37) December 15, 2024
ClusterScriptでの実装
さて、いよいよ本題です。エージェントの行動をClusterScriptで実装していきましょう。
アーキテクチャ確認
まずは、作成されたネットワークの構造を確認しましょう。NETRONというサイトにOnnxファイルをアップロードするとネットワークの構造・各レイヤーのパラメータを見ることができます。

パラメータ取得
各レイヤーのパラメータを取得します
例えば、最初の線形結合の重みパラメータは、入力が6つ, ノード数が32なので、32×6の行列となっています

流石に全部のパラメータをコピペするのは面倒なので、onnxファイルから直接取得するPythonコードを用意しました
import onnx
import numpy as np
def extract_parameters(onnx_file_path):
# ONNXファイルを読み込む
model = onnx.load(onnx_file_path)
# GraphProtoオブジェクトを取得する
graph = model.graph
parameters = {}
# ノードをイテレートして、初期化されたテンソルを取得する
for initializer in graph.initializer:
# パラメータ名を取得
param_name = initializer.name
# パラメータの値を取得
param_value = initializer.raw_data # バイナリデータ
# パラメータのデータ型を取得
param_dtype = initializer.data_type
# バイナリデータをNumPy配列に変換してshapeを取得
param_value_array = np.frombuffer(param_value, dtype=np.dtype(onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[param_dtype]))
param_shape = param_value_array.shape
# 取得した情報を辞書に格納
parameters[param_name] = {
'value': param_value_array,
'dtype': param_dtype,
'shape': param_shape
}
return parameters
# ONNXファイルのパスを指定してパラメータを取り出す
onnx_file_path = 'ml-agents\RollerDiscrete.onnx'
parameters = extract_parameters(onnx_file_path)
# print(parameters.keys())
#取り出したパラメータを表示
names = [
'network_body._body_endoder.seq_layers.0.weight',
'network_body._body_endoder.seq_layers.0.bias',
'network_body._body_endoder.seq_layers.2.weight',
'network_body._body_endoder.seq_layers.2.bias',
'action_model._discrete_distribution.branches.0.weight',
'action_model._discrete_distribution.branches.0.bias'
]
names_ = ['layer_0_weight','layer_0_bias','layer_1_weight','layer_1_bias','layer_2_weight','layer_2_bias']
shape_0 = np.array([[parameters[names[i*2+1]]['shape'][0]]*2 for i in range(len(names)//2)]).flatten()
i=0
# print(parameters['49']['value'])
for i in range(len(names)):
#print(f"{names[i]}")
array = np.round(parameters[names[i]]['value'], 4)
array = array.reshape(shape_0[i],-1)
#print(array.shape)
if len(array.shape) == 1:
array = np.array(array, dtype='<U')
array = array.tolist()
array = [float(i) for i in array]
else:
array = np.array(array, dtype='<U')
array = array.tolist()
array = [[float(j) for j in i] for i in array]
text = f"const {names_[i]} = {array};"
text = text.replace(', ',',')
print(text)
i+=1
出力
const layer_0_weight = [[-0.096,1.4366,0.3164,0.0943,0.3078,-0.5185],[0.2091,-0.7532,-0.3272,0.2045,0.2063,0.1847],[-0.5981,-0.1219,0.5163,0.659,-0.6621,-0.1017],[-0.1061,-0.1457,-0.0041,0.081,0.1856,1.023],[-0.3295,0.7966,0.0499,-0.1411,-0.1873,0.054],[0.3258,-0.2165,-0.047,-0.1441,0.5739,-0.1579],[0.0392,0.0692,-0.4151,-0.1111,-0.35,-0.0861],[-0.3587,0.0875,0.7904,-0.7204,0.6337,0.0169],[-0.3514,0.5545,0.5793,-0.0334,0.1557,0.3505],[-0.0318,0.1515,0.0394,-0.5405,-0.2361,-0.5503],[-0.159,0.2712,0.3241,-0.3505,-0.1752,-0.5385],[0.3229,-0.1738,0.5206,0.4438,0.1656,0.0897],[-0.3811,-0.186,-0.016,0.4245,0.1119,0.1004],[-0.2527,-0.3572,-0.2366,-0.0064,-0.0945,-0.0859],[0.3685,0.5748,0.2209,0.6606,-0.0312,0.4726],[0.0173,-0.379,0.5886,0.1019,0.7983,0.3637],[-0.3361,-0.041,-0.4772,0.3733,0.5477,-0.0432],[0.6798,-0.2952,-0.6016,0.3443,0.1831,0.4309],[0.7618,-0.0944,-0.3625,-0.2873,0.3701,-0.1353],[-0.7142,0.3816,0.5489,0.192,0.8652,-0.2168],[1.0374,-0.0542,0.0719,-0.6429,-0.0716,0.0679],[0.4568,0.6782,-0.9494,-0.2203,-0.6093,-0.1763],[-0.8011,0.3173,-0.9123,0.9854,-0.0913,0.2167],[-0.0039,0.1403,0.3662,-0.3699,0.3986,0.007],[0.0174,0.8514,-0.0344,0.1451,-0.1854,0.0102],[-0.4555,-0.1865,0.2999,-0.263,0.3953,-0.6235],[0.645,0.6281,-0.26,0.1899,-0.4517,-0.2174],[0.337,-0.372,-0.6275,-0.3325,0.0715,0.51],[-1.1203,0.0832,0.3684,0.3314,0.0568,0.1255],[0.9006,0.0772,0.3145,-0.0318,-0.552,0.1631],[0.3902,0.4222,0.0952,0.8058,-0.2516,0.7941],[0.3756,0.3196,0.0386,-1.003,0.6501,-0.0623]];
const layer_0_bias = [[0.1587],[0.1779],[0.164],[0.2506],[0.1331],[0.0311],[0.0634],[0.2075],[0.1427],[0.3085],[0.0804],[0.1392],[0.1441],[0.2494],[0.1987],[0.1022],[0.0852],[0.2213],[0.1957],[0.0696],[0.0458],[0.2747],[0.2903],[0.1094],[0.1427],[0.1575],[0.121],[0.2342],[0.0943],[0.1853],[0.103],[0.0101]];
const layer_1_weight = [[0.1669,0.0276,0.5035,0.0557,0.0926,-0.1641,-0.1214,0.1229,0.1708,0.13,-0.1748,-0.1189,0.1365,0.4358,-0.0092,0.0806,-0.1391,-0.2066,-0.1508,0.0935,-0.1517,-0.1262,-0.2469,0.318,0.0634,0.2602,0.08,-0.3802,0.5126,0.062,0.0155,0.0419],[0.1247,0.4001,0.1158,0.1577,-0.0534,-0.0473,-0.4249,0.0848,0.1261,-0.4173,-0.2411,0.2891,0.2715,-0.1284,0.2004,0.2281,0.3338,-0.3038,-0.0526,-0.1464,-0.0146,-0.5022,0.1603,-0.0213,-0.0153,-0.0242,-0.0686,0.178,0.2717,0.2066,0.3705,-0.1121],[-0.0069,0.167,0.2708,0.3153,-0.0191,0.0087,0.1732,-0.4079,-0.4019,-0.0944,-0.1592,0.2659,0.2113,0.4688,0.1908,-0.092,0.248,0.2726,0.191,-0.2792,0.0591,0.0797,0.0049,-0.1773,-0.1406,-0.1034,-0.0057,0.548,0.0612,0.165,-0.1131,-0.2897],[0.3031,-0.0019,0.0941,-0.0971,0.1445,-0.0063,0.2546,-0.355,-0.0769,0.573,-0.1048,-0.0665,-0.4332,0.5181,0.1085,-0.4082,-0.0102,0.0145,0.2679,-0.1304,0.2112,0.1318,0.2196,0.168,0.0979,-0.1408,0.1274,0.3097,-0.3234,0.2259,0.0952,0.0677],[0.3853,0.1171,-0.0694,-0.318,-0.0029,0.054,0.0561,0.132,0.571,0.4079,-0.1027,0.0146,0.0603,0.2326,-0.0844,-0.066,0.1495,-0.3413,0.251,-0.1884,0.034,0.1129,-0.0153,-0.0615,0.1749,0.2181,-0.1408,-0.2233,0.2176,0.044,-0.1731,0.2497],[0.1553,0.0812,-0.3518,-0.4324,-0.3304,0.0925,-0.0872,0.2376,0.1261,-0.2135,-0.0948,0.1945,0.0732,-0.1342,0.1721,0.2871,0.1913,0.0835,-0.1813,0.1544,0.0607,-0.0682,-0.2363,-0.2354,-0.1679,-0.0269,0.0185,-0.137,0.0697,0.1668,0.0288,-0.1477],[-0.1348,-0.1722,0.0227,-0.045,0.2626,0.0387,-0.3373,0.2377,0.1773,0.1732,-0.0138,0.181,-0.0436,0.2557,0.0561,-0.1695,0.1276,-0.3692,-0.2382,0.383,0.1573,-0.0963,0.1446,0.0969,0.279,0.3957,-0.1427,-0.261,0.005,0.0237,0.0167,-0.2335],[0.3002,-0.0024,-0.5242,-0.0735,-0.1611,0.0601,-0.006,0.1484,0.1416,0.4736,-0.0044,-0.0495,-0.1662,0.2498,0.0928,-0.3255,-0.3081,-0.1436,0.0013,0.2754,-0.175,0.0471,0.0803,0.3134,0.0316,0.1406,0.2458,-0.125,-0.2472,-0.1026,0.1853,0.2713],[-0.1335,-0.2589,0.0326,0.184,-0.0443,0.1055,0.185,0.1039,0.2385,0.3148,-0.0447,-0.2056,-0.3518,0.1792,0.0854,-0.1689,-0.0474,0.3313,0.0958,-0.1422,-0.1022,0.1267,-0.1742,-0.0582,-0.0584,0.1683,0.1202,0.1477,-0.3187,0.5481,0.1623,-0.0474],[-0.0331,0.5843,0.2271,0.2078,-0.3411,0.1889,-0.007,-0.4373,-0.0794,-0.1806,-0.197,-0.1826,0.2024,-0.0372,0.447,0.0049,0.1445,0.2919,-0.2781,0.053,0.3029,-0.1042,0.2978,0.0387,0.2475,-0.1642,-0.0873,0.1211,-0.0457,0.0967,0.083,-0.3168],[-0.3208,0.0063,0.3332,0.3583,-0.0943,0.0042,-0.1525,-0.0051,-0.482,-0.1752,-0.0928,0.047,0.3545,-0.0559,0.2719,0.1369,-0.0197,0.2143,0.0535,-0.121,-0.3936,-0.0625,0.1318,0.0049,-0.1444,0.1743,-0.2131,0.0352,0.1641,0.2774,0.1614,0.2106],[0.1988,-0.1021,-0.1817,0.0214,0.3587,0.3751,-0.0117,0.0649,0.0382,0.3446,0.2728,-0.5324,0.0558,-0.0313,-0.05,-0.0287,0.1916,-0.0137,0.0392,0.0127,-0.1302,0.2071,0.2367,0.2003,-0.1872,0.1972,0.0757,0.1108,0.0643,0.1002,-0.1399,-0.0516],[-0.1492,-0.0777,0.3324,0.265,-0.1001,0.0788,-0.4488,0.2324,0.4602,-0.2922,-0.0612,-0.063,0.0153,0.0788,0.147,0.28,0.1059,-0.3538,-0.108,-0.1988,0.1754,-0.3138,0.3712,0.373,-0.0458,0.2742,0.0882,-0.2552,0.0214,-0.1912,-0.1858,0.0951],[-0.3176,0.381,0.1272,0.1005,0.2787,-0.0523,-0.1188,-0.3182,-0.0509,0.1054,-0.3056,0.3863,0.0334,0.2624,0.1183,0.263,0.2478,0.2901,0.17,0.1146,0.16,-0.1459,0.2591,-0.008,0.0295,-0.0854,0.1147,-0.1003,0.2472,0.003,0.1289,-0.3831],[-0.1123,0.1326,0.1809,-0.0204,-0.0958,0.2573,-0.2264,0.0175,-0.1809,-0.0589,-0.279,-0.0795,-0.2746,0.3016,-0.0886,-0.2072,-0.1466,-0.0955,-0.4169,-0.3692,-0.0541,-0.1513,0.0875,-0.2415,-0.2907,0.0694,-0.6903,-0.1046,-0.0814,-0.1099,-0.5432,0.0188],[0.071,-0.0469,0.014,0.1713,0.1707,0.0714,-0.3115,0.098,0.2274,-0.1542,-0.0669,0.2835,0.1494,-0.0124,0.4041,0.0933,0.2654,0.2575,-0.15,0.1635,-0.1393,-0.2122,-0.1358,-0.1033,0.0528,0.3301,-0.4002,-0.1158,-0.1269,-0.1851,0.1057,0.4278],[0.1359,-0.5631,0.2462,0.0193,0.1735,-0.0935,-0.201,0.3988,0.1222,0.2268,-0.0025,0.106,0.1978,-0.057,-0.4751,0.2723,-0.0329,-0.2057,0.2226,0.0555,0.2155,-0.3357,-0.234,-0.0509,0.5856,0.2055,0.0658,0.0727,-0.1076,0.277,-0.031,0.001],[-0.1785,-0.0123,-0.0587,-0.2987,-0.1707,-0.2435,-0.0442,-0.0029,0.0439,-0.221,-0.1026,0.3133,0.0376,0.0927,-0.0068,0.3031,0.1084,0.3694,0.1356,0.3769,0.0946,-0.4781,0.0683,-0.2824,0.5459,-0.2103,-0.1496,0.1393,-0.2726,-0.0968,0.0212,-0.1576],[-0.2604,-0.1213,0.4434,-0.2111,0.3122,0.0365,0.4238,-0.1552,-0.205,0.0058,-0.0042,-0.09,0.0802,0.1099,0.2816,-0.0171,-0.3298,0.1776,0.0803,-0.4212,0.1765,0.352,0.3812,-0.0433,0.0718,-0.124,0.1352,0.0913,-0.1396,0.4326,0.1809,-0.2747],[0.1154,-0.1482,0.2585,-0.0515,0.0589,0.3604,-0.4226,0.498,-0.182,0.181,0.0869,-0.2523,-0.1073,0.0893,0.0958,0.2284,0.1968,-0.0878,0.3166,0.2008,-0.1484,-0.048,-0.0554,0.2156,-0.0329,0.2065,0.0629,-0.019,-0.1519,0.0315,0.0644,0.1602],[0.0884,-0.0551,0.1211,0.0942,-0.1086,0.2048,-0.201,0.3351,0.2398,-0.0256,0.1677,0.2357,0.1219,-0.1057,0.1617,0.3792,-0.2318,0.072,-0.0855,0.0033,0.025,-0.2944,0.1309,0.2302,0.2287,0.1234,-0.1595,-0.0771,0.2233,-0.1112,-0.0749,-0.0894],[0.2874,-0.1143,-0.0001,-0.1022,-0.0008,-0.1631,-0.0357,0.0298,0.2061,0.2171,0.5487,-0.0049,-0.2557,-0.0511,-0.2133,-0.3877,-0.1119,-0.059,0.2633,-0.0248,-0.1588,0.4442,0.2611,-0.2155,-0.0511,0.5308,0.119,0.2392,-0.373,-0.0024,0.0377,0.3819],[-0.0187,-0.1417,-0.0231,0.0223,0.0548,0.2345,0.0152,0.3351,0.2176,-0.5357,0.1314,0.0371,-0.432,0.1091,0.2047,0.1447,0.093,0.4422,0.1113,-0.0598,0.0463,-0.103,-0.0012,-0.0655,0.3697,0.1399,0.0746,-0.0082,0.0543,-0.0401,-0.0858,-0.2377],[-0.1077,0.2311,0.6711,0.3872,-0.1636,0.1502,-0.3061,-0.0781,0.0135,-0.2003,-0.1129,0.3422,0.0522,-0.0104,-0.0008,0.1597,0.1067,0.0307,0.2425,-0.1117,-0.1701,0.0414,0.0912,0.1474,-0.2316,-0.1275,0.0291,0.0937,-0.1116,0.0955,0.1575,-0.0126],[0.1069,0.1907,-0.1363,0.283,0.159,-0.07,-0.1306,-0.0471,-0.1267,-0.3098,0.029,-0.0754,-0.0258,0.1125,0.2783,0.1466,0.2865,-0.4118,-0.3699,0.4635,0.1639,-0.1333,0.1562,0.3238,0.0026,-0.1317,-0.0368,-0.2618,0.4006,0.1986,-0.0093,0.1205],[0.1293,-0.0755,-0.2152,0.2343,-0.0653,0.0711,-0.4889,0.0892,0.0682,0.0705,0.2782,-0.0553,-0.1894,0.0882,0.1036,0.5332,0.0543,0.156,0.0665,0.1175,0.0048,0.0973,0.0271,0.1576,-0.0015,0.6218,0.0051,-0.2385,0.1386,-0.0934,-0.1938,0.1451],[0.2446,-0.0435,-0.1682,-0.0129,0.0053,0.1288,0.6252,0.2111,-0.0118,0.5059,0.3459,-0.1714,0.0728,0.0072,-0.0163,-0.5989,-0.0352,0.2538,0.2533,-0.15,0.221,0.4723,0.0908,-0.1578,-0.0048,-0.0249,0.2119,0.2847,0.1673,0.2175,-0.1027,-0.0476],[-0.1499,-0.0164,-0.1347,0.1002,-0.1161,0.2968,-0.0735,0.2908,0.1746,-0.0315,-0.0161,0.1213,0.3803,0.0139,-0.1079,-0.0842,0.253,-0.2694,-0.1452,0.1602,-0.4265,0.0227,0.131,-0.0238,0.1471,0.1924,0.1296,-0.3204,0.221,-0.069,0.0312,-0.0424],[-0.1099,0.3896,0.158,0.3439,-0.0348,-0.236,-0.1547,-0.0526,-0.1305,-0.0804,-0.3307,-0.0828,0.0234,0.1896,0.1237,0.2796,-0.042,0.3084,0.1547,-0.3332,-0.0996,0.3325,0.3306,-0.418,-0.0011,-0.1108,-0.1428,-0.1618,-0.0003,0.3032,0.1209,0.0432],[0.2274,-0.0491,0.1426,-0.2707,0.1393,-0.0071,-0.2598,0.218,-0.0817,0.4575,0.2643,-0.0862,-0.0957,-0.0399,-0.084,-0.0333,-0.0812,-0.1721,0.2079,0.0821,0.0484,0.2512,-0.1289,0.1858,0.1135,0.0347,0.0103,-0.2636,0.1032,-0.0315,-0.1949,0.3029],[0.1529,0.4425,0.1133,-0.107,0.2554,-0.1387,0.1917,0.0026,-0.0007,0.3356,0.1758,0.1944,-0.0325,0.0358,0.1931,-0.0636,-0.5058,-0.2117,-0.0617,-0.1315,0.0945,0.2935,0.031,-0.1591,0.35,-0.0019,0.3545,0.1569,0.0114,0.6994,0.0315,0.1246],[0.1107,-0.2232,0.0568,-0.0058,0.0258,-0.1194,-0.1705,-0.1582,0.2945,0.3442,0.3935,-0.2653,0.1946,-0.2628,0.0392,-0.2068,-0.1318,0.1852,-0.1206,0.0283,0.2841,0.0829,-0.0598,0.449,0.0396,0.0583,-0.0017,-0.0963,0.3076,-0.3736,-0.1118,-0.1147]];
const layer_1_bias = [[0.0804],[0.2508],[0.2175],[0.2084],[0.0985],[0.064],[0.1153],[0.1513],[0.1307],[0.2146],[0.1337],[0.098],[0.1424],[0.1161],[0.1143],[0.2259],[0.1233],[0.0988],[0.1622],[0.1439],[0.2194],[0.1773],[0.0069],[0.111],[0.0538],[0.1267],[0.2062],[0.0912],[0.1879],[0.1724],[0.0751],[0.0294]];
const layer_2_weight = [[-0.0508,-0.0128,-0.0754,-0.0264,-0.0554,0.0351,-0.0569,-0.0316,-0.0336,-0.0103,-0.0294,-0.0008,-0.0005,-0.0301,0.0022,-0.0131,-0.0243,-0.0277,-0.0185,-0.0679,-0.0451,-0.0285,-0.0023,-0.0429,-0.0185,-0.0645,-0.0481,-0.004,-0.035,-0.0568,-0.0069,-0.0144],[0.0055,-0.1954,-0.1321,0.0755,0.1485,-0.0756,0.0547,0.0704,0.0011,-0.1664,-0.0822,0.1178,-0.0889,-0.1066,-0.0711,-0.0239,0.0985,-0.1259,-0.0779,0.1015,-0.0365,0.1283,-0.0517,-0.0771,-0.0335,0.0784,0.1126,-0.0025,-0.147,0.1849,0.0338,0.07],[0.016,0.2274,0.1202,-0.1601,-0.1173,0.0181,-0.0761,-0.0962,-0.0224,0.1589,0.1627,-0.1217,0.0485,0.1472,0.1446,0.0395,-0.1113,0.1162,-0.0056,-0.106,0.026,-0.1379,0.0003,0.145,-0.0127,-0.0599,-0.175,-0.004,0.1603,-0.16,-0.0343,-0.0499],[0.0442,0.0955,-0.1671,-0.1367,0.0308,0.1,0.1477,0.0072,-0.0994,-0.1242,0.0158,-0.0057,0.1264,0.0022,-0.0162,0.156,0.1215,0.0021,-0.1371,0.1201,0.1598,-0.1201,0.0385,-0.0318,0.0943,0.1145,-0.1348,0.093,-0.1393,-0.0137,-0.0332,0.0138],[-0.0888,-0.1513,0.1105,0.1629,-0.1065,-0.0427,-0.0953,0.0013,0.1019,0.0661,-0.0836,0.0222,-0.0919,0.0226,0.0135,-0.135,-0.0773,-0.0417,0.1417,-0.1903,-0.181,0.0402,-0.0103,-0.0025,-0.0569,-0.147,0.1611,-0.0718,0.0956,-0.0916,0.0889,-0.0246]];
const layer_2_bias = [[-0.0517],[0.0193],[0.0123],[-0.0125],[0.0285]];
ネットワーク作成
線形結合
入力に重み行列をかけてバイアスを加える最も基本的なネットワークです
2次元行列の積と和を実装します
const add = (matrix1, matrix2) => {
var res = [];
for(var i = 0; i < matrix1.length; i++){
res.push([]);
for(var j = 0; j < matrix1[0].length; j++){
res[i].push(matrix1[i][j] + matrix2[i][j]);
}
}
return res;
}
const dot = (matrix1, matrix2) => {
var res = [];
var row1 = matrix1.length;
var row2 = matrix2.length;
var col1 = matrix1[0].length;
var col2 = matrix2[0].length;
for(var i1 = 0; i1 < row1; i1++){
res.push([]);
for(var i2 = 0; i2 < col2; i2++){
res[i1].push(0);
for(var i3 = 0; i3 < col1; i3++){
res[i1][i2] += matrix1[i1][i3] * matrix2[i3][i2];
}
}
}
return res;
}
活性化関数
ネットワークに非線形性をもたせるために通す関数です
ここではシグモイド関数を通した後元の入力を掛け合わせています

const activation = (matrix) => {
var res = [];
for(var i = 0; i < matrix.length; i++){
res.push([]);
for(var j = 0; j < matrix[0].length; j++){
res[i].push((1 / (1 + Math.exp(-matrix[i][j])))*matrix[i][j]);
}
}
return res;
}
ソフトマックス関数
出力の合計が1になるようにする関数
後続のArgMaxに影響しないため省略
価値最大の行動取得(ArgMax)
出力の内、どの順番の値が最も大きいかを取得
const argmax = (array) => {
return array.indexOf(Math.max(...array));
}
推論実行
状態量取得
$.getPositionで自分の位置、$.velocityで自分の速度、$.subNode("Target").getGlobalPositionでConstraintした子オブジェクトを使ってCubeの位置を取得しています。
ネットワークに入力
Netronの表示通り、線形結合・活性化関数への入力を3回行います
行動選択
出力される値が最も大きい順番を取得し、それに応じた方向に加速します
const target = $.subNode("Target");
$.onStart(() => {
$.state.tick = 0;
});
$.onPhysicsUpdate(deltaTime => {
$.state.tick += deltaTime;
if ($.state.tick > 0.5) {
$.state.tick = 0;
const pos = $.getPosition()
const targetpos = target.getGlobalPosition()
const vel = $.velocity
let x = [[targetpos.x], [targetpos.z], [pos.x], [pos.z], [vel.x], [vel.z]]
x = dot(layer_0_weight, x);
x = add(x, layer_0_bias);
x = activation(x);
x = dot(layer_1_weight, x);
x = add(x, layer_1_bias);
x = activation(x);
x = dot(layer_2_weight, x);
x = add(x, layer_2_bias);
output = argmax(x.flat());
$.log(output)
if(output == 1) $.addForce(new Vector3(0, 0, 1).multiplyScalar(10));
if(output == 2) $.addForce(new Vector3(0, 0, -1).multiplyScalar(10));
if(output == 3) $.addForce(new Vector3(-1, 0, 0).multiplyScalar(10));
if(output == 4) $.addForce(new Vector3(1, 0, 0).multiplyScalar(10));
}
});
実行結果
無事、ClusterScriptでニューラルネットワークに基づく行動決定が実現できました
— さな (@ponpopon37) December 15, 2024
スクリプト全文
const layer_0_weight = [[-0.096,1.4366,0.3164,0.0943,0.3078,-0.5185],[0.2091,-0.7532,-0.3272,0.2045,0.2063,0.1847],[-0.5981,-0.1219,0.5163,0.659,-0.6621,-0.1017],[-0.1061,-0.1457,-0.0041,0.081,0.1856,1.023],[-0.3295,0.7966,0.0499,-0.1411,-0.1873,0.054],[0.3258,-0.2165,-0.047,-0.1441,0.5739,-0.1579],[0.0392,0.0692,-0.4151,-0.1111,-0.35,-0.0861],[-0.3587,0.0875,0.7904,-0.7204,0.6337,0.0169],[-0.3514,0.5545,0.5793,-0.0334,0.1557,0.3505],[-0.0318,0.1515,0.0394,-0.5405,-0.2361,-0.5503],[-0.159,0.2712,0.3241,-0.3505,-0.1752,-0.5385],[0.3229,-0.1738,0.5206,0.4438,0.1656,0.0897],[-0.3811,-0.186,-0.016,0.4245,0.1119,0.1004],[-0.2527,-0.3572,-0.2366,-0.0064,-0.0945,-0.0859],[0.3685,0.5748,0.2209,0.6606,-0.0312,0.4726],[0.0173,-0.379,0.5886,0.1019,0.7983,0.3637],[-0.3361,-0.041,-0.4772,0.3733,0.5477,-0.0432],[0.6798,-0.2952,-0.6016,0.3443,0.1831,0.4309],[0.7618,-0.0944,-0.3625,-0.2873,0.3701,-0.1353],[-0.7142,0.3816,0.5489,0.192,0.8652,-0.2168],[1.0374,-0.0542,0.0719,-0.6429,-0.0716,0.0679],[0.4568,0.6782,-0.9494,-0.2203,-0.6093,-0.1763],[-0.8011,0.3173,-0.9123,0.9854,-0.0913,0.2167],[-0.0039,0.1403,0.3662,-0.3699,0.3986,0.007],[0.0174,0.8514,-0.0344,0.1451,-0.1854,0.0102],[-0.4555,-0.1865,0.2999,-0.263,0.3953,-0.6235],[0.645,0.6281,-0.26,0.1899,-0.4517,-0.2174],[0.337,-0.372,-0.6275,-0.3325,0.0715,0.51],[-1.1203,0.0832,0.3684,0.3314,0.0568,0.1255],[0.9006,0.0772,0.3145,-0.0318,-0.552,0.1631],[0.3902,0.4222,0.0952,0.8058,-0.2516,0.7941],[0.3756,0.3196,0.0386,-1.003,0.6501,-0.0623]];
const layer_0_bias = [[0.1587],[0.1779],[0.164],[0.2506],[0.1331],[0.0311],[0.0634],[0.2075],[0.1427],[0.3085],[0.0804],[0.1392],[0.1441],[0.2494],[0.1987],[0.1022],[0.0852],[0.2213],[0.1957],[0.0696],[0.0458],[0.2747],[0.2903],[0.1094],[0.1427],[0.1575],[0.121],[0.2342],[0.0943],[0.1853],[0.103],[0.0101]];
const layer_1_weight = [[0.1669,0.0276,0.5035,0.0557,0.0926,-0.1641,-0.1214,0.1229,0.1708,0.13,-0.1748,-0.1189,0.1365,0.4358,-0.0092,0.0806,-0.1391,-0.2066,-0.1508,0.0935,-0.1517,-0.1262,-0.2469,0.318,0.0634,0.2602,0.08,-0.3802,0.5126,0.062,0.0155,0.0419],[0.1247,0.4001,0.1158,0.1577,-0.0534,-0.0473,-0.4249,0.0848,0.1261,-0.4173,-0.2411,0.2891,0.2715,-0.1284,0.2004,0.2281,0.3338,-0.3038,-0.0526,-0.1464,-0.0146,-0.5022,0.1603,-0.0213,-0.0153,-0.0242,-0.0686,0.178,0.2717,0.2066,0.3705,-0.1121],[-0.0069,0.167,0.2708,0.3153,-0.0191,0.0087,0.1732,-0.4079,-0.4019,-0.0944,-0.1592,0.2659,0.2113,0.4688,0.1908,-0.092,0.248,0.2726,0.191,-0.2792,0.0591,0.0797,0.0049,-0.1773,-0.1406,-0.1034,-0.0057,0.548,0.0612,0.165,-0.1131,-0.2897],[0.3031,-0.0019,0.0941,-0.0971,0.1445,-0.0063,0.2546,-0.355,-0.0769,0.573,-0.1048,-0.0665,-0.4332,0.5181,0.1085,-0.4082,-0.0102,0.0145,0.2679,-0.1304,0.2112,0.1318,0.2196,0.168,0.0979,-0.1408,0.1274,0.3097,-0.3234,0.2259,0.0952,0.0677],[0.3853,0.1171,-0.0694,-0.318,-0.0029,0.054,0.0561,0.132,0.571,0.4079,-0.1027,0.0146,0.0603,0.2326,-0.0844,-0.066,0.1495,-0.3413,0.251,-0.1884,0.034,0.1129,-0.0153,-0.0615,0.1749,0.2181,-0.1408,-0.2233,0.2176,0.044,-0.1731,0.2497],[0.1553,0.0812,-0.3518,-0.4324,-0.3304,0.0925,-0.0872,0.2376,0.1261,-0.2135,-0.0948,0.1945,0.0732,-0.1342,0.1721,0.2871,0.1913,0.0835,-0.1813,0.1544,0.0607,-0.0682,-0.2363,-0.2354,-0.1679,-0.0269,0.0185,-0.137,0.0697,0.1668,0.0288,-0.1477],[-0.1348,-0.1722,0.0227,-0.045,0.2626,0.0387,-0.3373,0.2377,0.1773,0.1732,-0.0138,0.181,-0.0436,0.2557,0.0561,-0.1695,0.1276,-0.3692,-0.2382,0.383,0.1573,-0.0963,0.1446,0.0969,0.279,0.3957,-0.1427,-0.261,0.005,0.0237,0.0167,-0.2335],[0.3002,-0.0024,-0.5242,-0.0735,-0.1611,0.0601,-0.006,0.1484,0.1416,0.4736,-0.0044,-0.0495,-0.1662,0.2498,0.0928,-0.3255,-0.3081,-0.1436,0.0013,0.2754,-0.175,0.0471,0.0803,0.3134,0.0316,0.1406,0.2458,-0.125,-0.2472,-0.1026,0.1853,0.2713],[-0.1335,-0.2589,0.0326,0.184,-0.0443,0.1055,0.185,0.1039,0.2385,0.3148,-0.0447,-0.2056,-0.3518,0.1792,0.0854,-0.1689,-0.0474,0.3313,0.0958,-0.1422,-0.1022,0.1267,-0.1742,-0.0582,-0.0584,0.1683,0.1202,0.1477,-0.3187,0.5481,0.1623,-0.0474],[-0.0331,0.5843,0.2271,0.2078,-0.3411,0.1889,-0.007,-0.4373,-0.0794,-0.1806,-0.197,-0.1826,0.2024,-0.0372,0.447,0.0049,0.1445,0.2919,-0.2781,0.053,0.3029,-0.1042,0.2978,0.0387,0.2475,-0.1642,-0.0873,0.1211,-0.0457,0.0967,0.083,-0.3168],[-0.3208,0.0063,0.3332,0.3583,-0.0943,0.0042,-0.1525,-0.0051,-0.482,-0.1752,-0.0928,0.047,0.3545,-0.0559,0.2719,0.1369,-0.0197,0.2143,0.0535,-0.121,-0.3936,-0.0625,0.1318,0.0049,-0.1444,0.1743,-0.2131,0.0352,0.1641,0.2774,0.1614,0.2106],[0.1988,-0.1021,-0.1817,0.0214,0.3587,0.3751,-0.0117,0.0649,0.0382,0.3446,0.2728,-0.5324,0.0558,-0.0313,-0.05,-0.0287,0.1916,-0.0137,0.0392,0.0127,-0.1302,0.2071,0.2367,0.2003,-0.1872,0.1972,0.0757,0.1108,0.0643,0.1002,-0.1399,-0.0516],[-0.1492,-0.0777,0.3324,0.265,-0.1001,0.0788,-0.4488,0.2324,0.4602,-0.2922,-0.0612,-0.063,0.0153,0.0788,0.147,0.28,0.1059,-0.3538,-0.108,-0.1988,0.1754,-0.3138,0.3712,0.373,-0.0458,0.2742,0.0882,-0.2552,0.0214,-0.1912,-0.1858,0.0951],[-0.3176,0.381,0.1272,0.1005,0.2787,-0.0523,-0.1188,-0.3182,-0.0509,0.1054,-0.3056,0.3863,0.0334,0.2624,0.1183,0.263,0.2478,0.2901,0.17,0.1146,0.16,-0.1459,0.2591,-0.008,0.0295,-0.0854,0.1147,-0.1003,0.2472,0.003,0.1289,-0.3831],[-0.1123,0.1326,0.1809,-0.0204,-0.0958,0.2573,-0.2264,0.0175,-0.1809,-0.0589,-0.279,-0.0795,-0.2746,0.3016,-0.0886,-0.2072,-0.1466,-0.0955,-0.4169,-0.3692,-0.0541,-0.1513,0.0875,-0.2415,-0.2907,0.0694,-0.6903,-0.1046,-0.0814,-0.1099,-0.5432,0.0188],[0.071,-0.0469,0.014,0.1713,0.1707,0.0714,-0.3115,0.098,0.2274,-0.1542,-0.0669,0.2835,0.1494,-0.0124,0.4041,0.0933,0.2654,0.2575,-0.15,0.1635,-0.1393,-0.2122,-0.1358,-0.1033,0.0528,0.3301,-0.4002,-0.1158,-0.1269,-0.1851,0.1057,0.4278],[0.1359,-0.5631,0.2462,0.0193,0.1735,-0.0935,-0.201,0.3988,0.1222,0.2268,-0.0025,0.106,0.1978,-0.057,-0.4751,0.2723,-0.0329,-0.2057,0.2226,0.0555,0.2155,-0.3357,-0.234,-0.0509,0.5856,0.2055,0.0658,0.0727,-0.1076,0.277,-0.031,0.001],[-0.1785,-0.0123,-0.0587,-0.2987,-0.1707,-0.2435,-0.0442,-0.0029,0.0439,-0.221,-0.1026,0.3133,0.0376,0.0927,-0.0068,0.3031,0.1084,0.3694,0.1356,0.3769,0.0946,-0.4781,0.0683,-0.2824,0.5459,-0.2103,-0.1496,0.1393,-0.2726,-0.0968,0.0212,-0.1576],[-0.2604,-0.1213,0.4434,-0.2111,0.3122,0.0365,0.4238,-0.1552,-0.205,0.0058,-0.0042,-0.09,0.0802,0.1099,0.2816,-0.0171,-0.3298,0.1776,0.0803,-0.4212,0.1765,0.352,0.3812,-0.0433,0.0718,-0.124,0.1352,0.0913,-0.1396,0.4326,0.1809,-0.2747],[0.1154,-0.1482,0.2585,-0.0515,0.0589,0.3604,-0.4226,0.498,-0.182,0.181,0.0869,-0.2523,-0.1073,0.0893,0.0958,0.2284,0.1968,-0.0878,0.3166,0.2008,-0.1484,-0.048,-0.0554,0.2156,-0.0329,0.2065,0.0629,-0.019,-0.1519,0.0315,0.0644,0.1602],[0.0884,-0.0551,0.1211,0.0942,-0.1086,0.2048,-0.201,0.3351,0.2398,-0.0256,0.1677,0.2357,0.1219,-0.1057,0.1617,0.3792,-0.2318,0.072,-0.0855,0.0033,0.025,-0.2944,0.1309,0.2302,0.2287,0.1234,-0.1595,-0.0771,0.2233,-0.1112,-0.0749,-0.0894],[0.2874,-0.1143,-0.0001,-0.1022,-0.0008,-0.1631,-0.0357,0.0298,0.2061,0.2171,0.5487,-0.0049,-0.2557,-0.0511,-0.2133,-0.3877,-0.1119,-0.059,0.2633,-0.0248,-0.1588,0.4442,0.2611,-0.2155,-0.0511,0.5308,0.119,0.2392,-0.373,-0.0024,0.0377,0.3819],[-0.0187,-0.1417,-0.0231,0.0223,0.0548,0.2345,0.0152,0.3351,0.2176,-0.5357,0.1314,0.0371,-0.432,0.1091,0.2047,0.1447,0.093,0.4422,0.1113,-0.0598,0.0463,-0.103,-0.0012,-0.0655,0.3697,0.1399,0.0746,-0.0082,0.0543,-0.0401,-0.0858,-0.2377],[-0.1077,0.2311,0.6711,0.3872,-0.1636,0.1502,-0.3061,-0.0781,0.0135,-0.2003,-0.1129,0.3422,0.0522,-0.0104,-0.0008,0.1597,0.1067,0.0307,0.2425,-0.1117,-0.1701,0.0414,0.0912,0.1474,-0.2316,-0.1275,0.0291,0.0937,-0.1116,0.0955,0.1575,-0.0126],[0.1069,0.1907,-0.1363,0.283,0.159,-0.07,-0.1306,-0.0471,-0.1267,-0.3098,0.029,-0.0754,-0.0258,0.1125,0.2783,0.1466,0.2865,-0.4118,-0.3699,0.4635,0.1639,-0.1333,0.1562,0.3238,0.0026,-0.1317,-0.0368,-0.2618,0.4006,0.1986,-0.0093,0.1205],[0.1293,-0.0755,-0.2152,0.2343,-0.0653,0.0711,-0.4889,0.0892,0.0682,0.0705,0.2782,-0.0553,-0.1894,0.0882,0.1036,0.5332,0.0543,0.156,0.0665,0.1175,0.0048,0.0973,0.0271,0.1576,-0.0015,0.6218,0.0051,-0.2385,0.1386,-0.0934,-0.1938,0.1451],[0.2446,-0.0435,-0.1682,-0.0129,0.0053,0.1288,0.6252,0.2111,-0.0118,0.5059,0.3459,-0.1714,0.0728,0.0072,-0.0163,-0.5989,-0.0352,0.2538,0.2533,-0.15,0.221,0.4723,0.0908,-0.1578,-0.0048,-0.0249,0.2119,0.2847,0.1673,0.2175,-0.1027,-0.0476],[-0.1499,-0.0164,-0.1347,0.1002,-0.1161,0.2968,-0.0735,0.2908,0.1746,-0.0315,-0.0161,0.1213,0.3803,0.0139,-0.1079,-0.0842,0.253,-0.2694,-0.1452,0.1602,-0.4265,0.0227,0.131,-0.0238,0.1471,0.1924,0.1296,-0.3204,0.221,-0.069,0.0312,-0.0424],[-0.1099,0.3896,0.158,0.3439,-0.0348,-0.236,-0.1547,-0.0526,-0.1305,-0.0804,-0.3307,-0.0828,0.0234,0.1896,0.1237,0.2796,-0.042,0.3084,0.1547,-0.3332,-0.0996,0.3325,0.3306,-0.418,-0.0011,-0.1108,-0.1428,-0.1618,-0.0003,0.3032,0.1209,0.0432],[0.2274,-0.0491,0.1426,-0.2707,0.1393,-0.0071,-0.2598,0.218,-0.0817,0.4575,0.2643,-0.0862,-0.0957,-0.0399,-0.084,-0.0333,-0.0812,-0.1721,0.2079,0.0821,0.0484,0.2512,-0.1289,0.1858,0.1135,0.0347,0.0103,-0.2636,0.1032,-0.0315,-0.1949,0.3029],[0.1529,0.4425,0.1133,-0.107,0.2554,-0.1387,0.1917,0.0026,-0.0007,0.3356,0.1758,0.1944,-0.0325,0.0358,0.1931,-0.0636,-0.5058,-0.2117,-0.0617,-0.1315,0.0945,0.2935,0.031,-0.1591,0.35,-0.0019,0.3545,0.1569,0.0114,0.6994,0.0315,0.1246],[0.1107,-0.2232,0.0568,-0.0058,0.0258,-0.1194,-0.1705,-0.1582,0.2945,0.3442,0.3935,-0.2653,0.1946,-0.2628,0.0392,-0.2068,-0.1318,0.1852,-0.1206,0.0283,0.2841,0.0829,-0.0598,0.449,0.0396,0.0583,-0.0017,-0.0963,0.3076,-0.3736,-0.1118,-0.1147]];
const layer_1_bias = [[0.0804],[0.2508],[0.2175],[0.2084],[0.0985],[0.064],[0.1153],[0.1513],[0.1307],[0.2146],[0.1337],[0.098],[0.1424],[0.1161],[0.1143],[0.2259],[0.1233],[0.0988],[0.1622],[0.1439],[0.2194],[0.1773],[0.0069],[0.111],[0.0538],[0.1267],[0.2062],[0.0912],[0.1879],[0.1724],[0.0751],[0.0294]];
const layer_2_weight = [[-0.0508,-0.0128,-0.0754,-0.0264,-0.0554,0.0351,-0.0569,-0.0316,-0.0336,-0.0103,-0.0294,-0.0008,-0.0005,-0.0301,0.0022,-0.0131,-0.0243,-0.0277,-0.0185,-0.0679,-0.0451,-0.0285,-0.0023,-0.0429,-0.0185,-0.0645,-0.0481,-0.004,-0.035,-0.0568,-0.0069,-0.0144],[0.0055,-0.1954,-0.1321,0.0755,0.1485,-0.0756,0.0547,0.0704,0.0011,-0.1664,-0.0822,0.1178,-0.0889,-0.1066,-0.0711,-0.0239,0.0985,-0.1259,-0.0779,0.1015,-0.0365,0.1283,-0.0517,-0.0771,-0.0335,0.0784,0.1126,-0.0025,-0.147,0.1849,0.0338,0.07],[0.016,0.2274,0.1202,-0.1601,-0.1173,0.0181,-0.0761,-0.0962,-0.0224,0.1589,0.1627,-0.1217,0.0485,0.1472,0.1446,0.0395,-0.1113,0.1162,-0.0056,-0.106,0.026,-0.1379,0.0003,0.145,-0.0127,-0.0599,-0.175,-0.004,0.1603,-0.16,-0.0343,-0.0499],[0.0442,0.0955,-0.1671,-0.1367,0.0308,0.1,0.1477,0.0072,-0.0994,-0.1242,0.0158,-0.0057,0.1264,0.0022,-0.0162,0.156,0.1215,0.0021,-0.1371,0.1201,0.1598,-0.1201,0.0385,-0.0318,0.0943,0.1145,-0.1348,0.093,-0.1393,-0.0137,-0.0332,0.0138],[-0.0888,-0.1513,0.1105,0.1629,-0.1065,-0.0427,-0.0953,0.0013,0.1019,0.0661,-0.0836,0.0222,-0.0919,0.0226,0.0135,-0.135,-0.0773,-0.0417,0.1417,-0.1903,-0.181,0.0402,-0.0103,-0.0025,-0.0569,-0.147,0.1611,-0.0718,0.0956,-0.0916,0.0889,-0.0246]];
const layer_2_bias = [[-0.0517],[0.0193],[0.0123],[-0.0125],[0.0285]];
const add = (matrix1, matrix2) => {
var res = [];
for(var i = 0; i < matrix1.length; i++){
res.push([]);
for(var j = 0; j < matrix1[0].length; j++){
res[i].push(matrix1[i][j] + matrix2[i][j]);
}
}
return res;
}
const dot = (matrix1, matrix2) => {
var res = [];
var row1 = matrix1.length;
var row2 = matrix2.length;
var col1 = matrix1[0].length;
var col2 = matrix2[0].length;
for(var i1 = 0; i1 < row1; i1++){
res.push([]);
for(var i2 = 0; i2 < col2; i2++){
res[i1].push(0);
for(var i3 = 0; i3 < col1; i3++){
res[i1][i2] += matrix1[i1][i3] * matrix2[i3][i2];
}
}
}
return res;
}
const activation = (matrix) => {
var res = [];
for(var i = 0; i < matrix.length; i++){
res.push([]);
for(var j = 0; j < matrix[0].length; j++){
res[i].push((1 / (1 + Math.exp(-matrix[i][j])))*matrix[i][j]);
}
}
return res;
}
const argmax = (array) => {
return array.indexOf(Math.max(...array));
}
const target = $.subNode("Target");
$.onStart(() => {
$.state.tick = 0;
});
$.onPhysicsUpdate(deltaTime => {
$.state.tick += deltaTime;
if ($.state.tick > 0.5) {
$.state.tick = 0;
const pos = $.getPosition()
const targetpos = target.getGlobalPosition()
const vel = $.velocity
let x = [[targetpos.x], [targetpos.z], [pos.x], [pos.z], [vel.x], [vel.z]]
x = dot(layer_0_weight, x);
x = add(x, layer_0_bias);
x = activation(x);
x = dot(layer_1_weight, x);
x = add(x, layer_1_bias);
x = activation(x);
x = dot(layer_2_weight, x);
x = add(x, layer_2_bias);
$.log(x.flat())
output = argmax(x.flat());
$.log(output)
if(output == 1) $.addForce(new Vector3(0, 0, 1).multiplyScalar(10));
if(output == 2) $.addForce(new Vector3(0, 0, -1).multiplyScalar(10));
if(output == 3) $.addForce(new Vector3(-1, 0, 0).multiplyScalar(10));
if(output == 4) $.addForce(new Vector3(1, 0, 0).multiplyScalar(10));
}
});
応用例
ここまで、近くのゴールの位置に向かうだけの非常にシンプルな例で実装方法を説明しました。ここで、もう少し複雑なシチュエーションのタスクをこなすエージェントを紹介します。
壁を回りながら追いかけてくる鬼
— さな (@ponpopon37) December 15, 2024
追うわんこと逃げるわんこ
— さな (@ponpopon37) December 15, 2024
AI同士でのホッケー対決
— さな (@ponpopon37) December 15, 2024
こちらを応用して、下記ワールドには一人用対戦Botも実装されています。あえて学習量を制限して低難易度のBotを作成するなどもしています。
まだ実装できていませんが、ml-agent自体はドッジボールのような高度な対戦をさせることもできるみたいなので、是非挑戦してみたいですね
まとめ
今回は、ml-agentで学習したニューラルネットワークをClusterScriptで動かし、深層強化学習AIを実装する方法を紹介しました。clusterの対戦ゲームには、中身が良くても人を集めるハードルが高くて遊ばれにくい現状があると思いまが、こうしたAI技術を使って代用するのは1つ面白いアプローチかもしれません。
ただ、現状UnityC#とCKKの両方でゲーム環境を作成するというハードルがあるので何とかしたいところ。CCKのsignalをml-agentのAddReward関数に結び付けたりすればCCK環境で学習させられそうな気がするのですが… Unityエディタ拡張とかCCK内部に詳しい人と相談して何とかできないかな…
あと、AIの挙動の調整ももう少し上手くやりたいところ。ホッケーBotは来たパックに応答速度と精度にものを言わせて位置を合わせ続けるだけみたいな節もある…確かに最適なんだけどさ。攻撃的なプレーにボーナスあげるとかしてみても良いかも