-
-
Save 13gav59/93ed2f8e2b99bcef7afcec9eace57938 to your computer and use it in GitHub Desktop.
ML-Agents
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using UnityEngine; | |
using MLAgents; | |
/// <summary>
/// Academy for the coin-collecting environment. Mirrors the "coin_speed" float
/// property (registered through <c>FloatProperties</c>) into <see cref="CoinSpeed"/>.
/// </summary>
public class PlayerAcademy : Academy
{
    /// <summary>
    /// Most recent value received for the "coin_speed" property; reset to 0 when the
    /// academy initializes.
    /// </summary>
    public float CoinSpeed { get; private set; }

    public override void InitializeAcademy()
    {
        CoinSpeed = 0;
        FloatProperties.RegisterCallback("coin_speed", HandleCoinSpeedChanged);
    }

    // Callback registered for the "coin_speed" property: stores the new value.
    private void HandleCoinSpeedChanged(float newSpeed)
    {
        CoinSpeed = newSpeed;
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using UnityEngine; | |
using MLAgents; | |
/// <summary>
/// ML-Agents agent that moves around its parent <see cref="PlayerArea"/> and is
/// rewarded for touching objects tagged "coin".
/// </summary>
/// <remarks>
/// Action vector as interpreted by <see cref="AgentAction"/> and produced by
/// <see cref="Heuristic"/>:
/// [0] forward amount (the heuristic emits 0 or 1);
/// [1] turn choice: 0 = none, 1 = turn left, 2 = turn right.
/// </remarks>
public class PlayerAgent : Agent
{
    /// <summary>Forward movement speed, in units per second.</summary>
    public float moveSpeed = 5f;

    /// <summary>Turn rate, in degrees per second.</summary>
    public float turnSpeed = 180f;

    private PlayerArea playerArea;

    // NOTE(review): cached in InitializeAgent() but not read anywhere in this class.
    private PlayerAcademy playerAcademy;

    // Cached Rigidbody; `new` hides the obsolete Component.rigidbody shortcut.
    new private Rigidbody rigidbody;

    public override void InitializeAgent()
    {
        base.InitializeAgent();
        playerArea = GetComponentInParent<PlayerArea>();
        playerAcademy = FindObjectOfType<PlayerAcademy>();
        rigidbody = GetComponent<Rigidbody>();
    }

    /// <summary>
    /// Applies one step of movement/turning and a small time penalty.
    /// </summary>
    /// <param name="vectorAction">See the class remarks for the layout.</param>
    public override void AgentAction(float[] vectorAction)
    {
        float forwardAmount = vectorAction[0];

        float turnAmount = 0f;
        if (vectorAction[1] == 1f)
        {
            turnAmount = -1f;
        }
        else if (vectorAction[1] == 2f)
        {
            turnAmount = 1f;
        }

        rigidbody.MovePosition(transform.position + transform.forward * forwardAmount * moveSpeed * Time.fixedDeltaTime);

        // Yaw about the agent's own (local) up axis. Passing the world-space transform.up
        // as local Euler angles would add unintended pitch/roll once the agent is tilted.
        transform.Rotate(Vector3.up, turnAmount * turnSpeed * Time.fixedDeltaTime);

        // Per-step penalty that encourages finishing quickly. A maxStep of 0 means
        // "no step limit" in ML-Agents; dividing by it would produce a -Infinity reward.
        if (agentParameters.maxStep > 0)
        {
            AddReward(-1f / agentParameters.maxStep);
        }
    }

    /// <summary>
    /// Keyboard control for manual testing: W = forward, A = turn left, D = turn right.
    /// </summary>
    /// <returns>An action vector in the layout expected by <see cref="AgentAction"/>.</returns>
    public override float[] Heuristic()
    {
        float forwardAction = 0f;
        float turnAction = 0f;

        if (Input.GetKey(KeyCode.W))
        {
            forwardAction = 1f;
        }

        if (Input.GetKey(KeyCode.A))
        {
            turnAction = 1f;
        }
        else if (Input.GetKey(KeyCode.D))
        {
            turnAction = 2f;
        }

        return new float[] { forwardAction, turnAction };
    }

    public override void AgentReset()
    {
        playerArea.ResetArea();
    }

    /// <summary>
    /// Observations: the number of coins still in the area, then the agent's forward vector.
    /// </summary>
    public override void CollectObservations()
    {
        AddVectorObs(playerArea.coinList.Count);
        AddVectorObs(transform.forward);
    }

    private void OnCollisionEnter(Collision collision)
    {
        if (collision.transform.CompareTag("coin"))
        {
            CollectCoin(collision.gameObject);
        }
    }

    /// <summary>
    /// Removes a coin from the area, rewards the agent, and ends the episode when no
    /// coins remain.
    /// </summary>
    private void CollectCoin(GameObject obj)
    {
        // Destroy() only takes effect at the end of the frame, so the same coin may raise
        // another collision (e.g. if the agent has more than one collider) before it is
        // gone. Only reward coins the area is still tracking.
        if (!playerArea.coinList.Contains(obj))
        {
            return;
        }

        playerArea.RemoveSpecificCoin(obj);
        AddReward(1f);

        if (playerArea.CoinsRemaining <= 0)
        {
            Done();
        }
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System.Collections.Generic; | |
using UnityEngine; | |
using MLAgents; | |
using TMPro; | |
/// <summary>
/// Training area that owns one <see cref="PlayerAgent"/> and the coins it collects.
/// A reset clears existing coins, repositions the player and spawns a fresh set of coins.
/// </summary>
public class PlayerArea : Area
{
    public PlayerAgent playerAgent;

    /// <summary>Optional HUD text showing the agent's cumulative reward.</summary>
    public TextMeshPro rewardText;

    public Coin coinPrefab;

    /// <summary>Number of coins spawned on every reset (previously hard-coded to 5).</summary>
    public int coinCount = 5;

    // NOTE(review): cached in Start() but not read anywhere in this class.
    private PlayerAcademy playerAcademy;

    /// <summary>Coin objects currently alive in this area.</summary>
    public List<GameObject> coinList;

    public override void ResetArea()
    {
        RemoveAllCoins();
        PlacePlayer();
        SpawnCoins(coinCount);
    }

    /// <summary>Stops tracking <paramref name="coinObj"/> and destroys it.</summary>
    public void RemoveSpecificCoin(GameObject coinObj)
    {
        coinList.Remove(coinObj);
        Destroy(coinObj);
    }

    /// <summary>Number of coins still tracked by this area.</summary>
    public int CoinsRemaining
    {
        get { return coinList.Count; }
    }

    /// <summary>
    /// Picks a random position relative to <paramref name="center"/>.
    /// </summary>
    /// <param name="c">
    /// Region selector: 0 is used for the player spawn (see PlacePlayer), 1 for coin
    /// spawns (see SpawnCoins); any other value returns <paramref name="center"/> unchanged.
    /// </param>
    /// <param name="center">Origin the random offset is added to.</param>
    public static Vector3 ChooseRandomPosition(int c, Vector3 center)
    {
        if (c == 0)
        {
            return center + new Vector3(UnityEngine.Random.Range(1f, 4f), 2f, UnityEngine.Random.Range(-1f, 1f));
        }
        else if (c == 1)
        {
            return center + new Vector3(UnityEngine.Random.Range(-4.5f, 4.5f), 1.85f, UnityEngine.Random.Range(-2.5f, 2.5f));
        }
        else
        {
            return center + new Vector3(0, 0, 0);
        }
    }

    // Destroys every tracked coin and starts a fresh, empty list.
    private void RemoveAllCoins()
    {
        if (coinList != null)
        {
            for (int i = 0; i < coinList.Count; i++)
            {
                if (coinList[i] != null)
                {
                    Destroy(coinList[i]);
                }
            }
        }

        coinList = new List<GameObject>();
    }

    // Stops the player's motion and moves it to a random spot in the player region.
    private void PlacePlayer()
    {
        Rigidbody rigidbody = playerAgent.GetComponent<Rigidbody>();
        rigidbody.velocity = Vector3.zero;
        rigidbody.angularVelocity = Vector3.zero;
        playerAgent.transform.position = ChooseRandomPosition(0, transform.position) + Vector3.up * .5f;
        // Random initial heading is currently disabled:
        // playerAgent.transform.rotation = Quaternion.Euler(0f, UnityEngine.Random.Range(0f, 360f), 0f);
    }

    // Instantiates `c` coins in the coin region, parented to this area, and tracks them.
    private void SpawnCoins(int c)
    {
        for (int i = 0; i < c; i++)
        {
            GameObject coinObj = Instantiate<GameObject>(coinPrefab.gameObject);
            coinObj.transform.position = ChooseRandomPosition(1, transform.position) + Vector3.up * .5f;
            coinObj.transform.SetParent(transform);
            coinList.Add(coinObj);
        }
    }

    private void Start()
    {
        playerAcademy = FindObjectOfType<PlayerAcademy>();
        ResetArea();
    }

    private void Update()
    {
        // The reward display is optional; without this guard an unassigned field would
        // throw a NullReferenceException every frame. Unity's overloaded == also treats
        // destroyed objects as null.
        if (rewardText == null)
        {
            return;
        }

        rewardText.text = playerAgent.GetCumulativeReward().ToString("0.00");
    }
}
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.