フリーランスの暇な時に書く技術メモ
Genki Tech

Unity + ML-Agentsでディープラーニングを始めた記事の解説

前回の記事では、Unity + ML-Agentsで環境を作って学習させるところまで出来ました。

以前からやりたいと思っていたAI技術、ディープラーニングを遂に始めました。元々これがやりたくて、やるなら記事を書きたいと思ってWordPressやGatsbyを使い始めたので、1カ月近く回り道してしまいました。 更新が早いせいか古い情報も沢山あったため、出来るだけ最新の情報でま…

今回はもうちょっと踏み込んで、コードや設定内容の説明をしたいと思います。

ものすごくザックリ解説

次のような構造で学んでいきます。

  1. 観測:Observation ① ② ③ ④ ⑤ ⑥ ⑦ ⑧ (自分と相手の座標や速度を入力)

  2. ブレイン(頭脳)。ML-Agentsの学習機能

  3. 実行:Action ① ②(左右上下の操作指示を実行)

  4. 報酬を与える

    入力に戻って繰り返す。

もうちょっとちゃんと説明

1. 観測

この部分では、ブレインが考えるのに必要と思われる情報を入力します。

今回の例だとターゲットのXYZ座標+自分のXYZ座標+自分の速度XZ情報で8つの実数を渡しています。これがX座標だよ!みたいな意味付けは必要なく、座標情報だけ渡せばOKです。後は自分で考えてくれます。

入力する値の数は、ボールに設定しているのコンポーネント「Behavior Parameters」>「Vector Observation」>「Space Size」で指定しています。

コード上では以下の部分で必要な値をブレインに渡しています。

    public override void CollectObservations(VectorSensor sensor)
    {
        sensor.AddObservation(target.localPosition); //TargetのXYZ座標
        sensor.AddObservation(this.transform.localPosition); //RollerAgentのXYZ座標
        sensor.AddObservation(rBody.velocity.x); // RollerAgentのX速度
        sensor.AddObservation(rBody.velocity.z); // RollerAgentのZ速度
    }

2. ブレイン(頭脳)

学ばせたいことの内容によって、適切なブレインの性能や大きさを設定します。ちなみに何を考えているかは人間にはわかりません。

この設定は前回の例では、プロジェクトフォルダにあるConfigフォルダのrollerball_config.yamlに記述しています。

behaviors:
  RollerBall:
    max_steps: 1000000

前回はほとんど初期値で動かしてしまいましたが、以下のような細かい設定が可能です。

behaviors:
  RollerBall:
    trainer_type:   ppo
    hyperparameters:
      batch_size:   1024
      buffer_size:  10240
      learning_rate:        0.0003
      beta: 0.005
      epsilon:      0.2
      lambd:        0.95
    network_settings:
      normalize:    False
      hidden_units: 128
      num_layers:   2
    reward_signals:
      extrinsic:
        gamma:      0.99
        strength:   1.0
    checkpoint_interval:    500000
    max_steps:      1000000
    time_horizon:   64
    summary_freq:   50000

重要な部分として、「hidden_units」や「num_layers」は、頭脳で言うところのシナプスの数や連結数に似た、ブレインそのものの能力に影響する部分なので、適切に設定してあげる必要があるようです。

簡単なことをさせたいのにあまり大きな値にすると、頭デッカチで考えすぎて正解にたどり着くのに時間がかかる子になり、逆にあまり複雑なことを小さい頭でさせようとすると不可能になるようです。

他にも、「自分今の良かったんじゃない?」と考える間隔などの設定も出来るようです。細かい内容はML-Agentsのリファレンスなどをご覧ください。

3. 実行

この部分では、頭脳が考えた結果をオブジェクトに対する操作として出力します。 今回の例だと左右上下に入力してボールを動かすため、XYの2軸が出力となります。

出力する値の数は、ボールのコンポーネント「Behavior Parameters」>「Vector Action」>「Space Size」で指定しています。

コード上では以下の部分で、頭脳の出力を操作に変換しています。

    public override void OnActionReceived(ActionBuffers actionBuffers)
    {
        // ボールに力を加える
        rBody.AddForce(actionBuffers.ContinuousActions[0] * 10, 0, actionBuffers.ContinuousActions[1] * 10);

また、ブレインモデル未セット時にキーボードで操作出来るようになっています。 acrtionsにセットした内容がそのままOnActionReceivedのactionBuffersに渡るイメージです。

    public override void Heuristic(in ActionBuffers actionsOut)
    {
        var actions = actionsOut.ContinuousActions;
        actions[0] = Input.GetAxis("Horizontal");
        actions[1] = Input.GetAxis("Vertical");
    }

4. 報酬を与える

学ぶためには、適切な時に褒め、適切な時に叱ってあげる必要があります。人間と同じです。

褒めるにはAddRewardを呼び出します。先ほどの関数の続きで以下の様に処理しています。

    // 前回より近づいた分報酬。離れた分お叱り
    var nowDistance = Vector3.Distance(this.transform.localPosition, target.localPosition);
    AddReward((previousDistance - nowDistance) * 0.1f);
    previousDistance = nowDistance;

    //時間の経過とともにお叱り
    AddReward(-0.01f);

    // 箱に接触した場合は報酬を与えてリセット。
    if (Vector3.Distance(this.transform.localPosition, target.localPosition) < 1.0f)
    {
        AddReward(10.0f);
        EndEpisode();
    }

    // 落下した時場合はお叱りを与えてリセット。
    if (this.transform.localPosition.y < 0)
    {
        AddReward(-13.0f);
        EndEpisode();
    }
}

こんな感じで報酬を与えています。

  1. ターゲットを捕らえたら褒める!
  2. 落っこちたら叱る!
  3. 時間がたつとグチグチ責める。(これは必要ないかもしれない)
  4. 近づいたらちょっと褒め、離れたらちょっとけなす。(これをやると学ぶのが早そう)

これらはバランスが大切で、例えば時間経過でめっちゃ叱ると、叱られたくないので最速で落っこちに行きます。このサイクルによって、出来るだけ褒めてもらえるようなパターンを見つけるように頭脳が学び続けてくれます。

その他気が付いたこと

エピソードの長さと成長速度

EndEpisodeで内部の報酬値がいったんリセットされる様なので、エピソード終了地点の報酬が関係あるのかな?と最初は思っていましたが、エピソードの終了タイミングと成長具合はあまり関係が無いようです。

試しにエピソードを一度も終わらせずに見ていても(もちろん落ちたら元の位置に戻しますが)、同じ速度で成長していきました。

時間の経過で叱る必要はない

もともと短時間で出来るだけ多くの報酬を得るように成長してくれるため、時間の経過で罰を与えるような実装はしない様がよさそうです。(そうする事で無駄に自分から落ちに行くように成長してしまう、という事を避けられます)