Skip to content

Instantly share code, notes, and snippets.

@13gav59
Created April 24, 2020 15:35
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save 13gav59/93ed2f8e2b99bcef7afcec9eace57938 to your computer and use it in GitHub Desktop.
ML-Agents
using UnityEngine;
using MLAgents;
/// <summary>
/// Academy for the coin-collecting environment. Exposes the "coin_speed"
/// float property (settable from the trainer side) as <see cref="CoinSpeed"/>.
/// </summary>
public class PlayerAcademy : Academy
{
    /// <summary>
    /// Latest value received for the "coin_speed" float property; starts at 0.
    /// </summary>
    public float CoinSpeed { get; private set; }

    /// <summary>
    /// Resets <see cref="CoinSpeed"/> and subscribes to updates of the
    /// "coin_speed" property.
    /// </summary>
    public override void InitializeAcademy()
    {
        CoinSpeed = 0f;
        FloatProperties.RegisterCallback("coin_speed", HandleCoinSpeedChanged);
    }

    // Invoked whenever a new "coin_speed" value is delivered.
    private void HandleCoinSpeedChanged(float newSpeed)
    {
        CoinSpeed = newSpeed;
    }
}
using UnityEngine;
using MLAgents;
/// <summary>
/// Agent that drives forward and turns to collect coins in its parent
/// <see cref="PlayerArea"/>. The episode ends once every coin is collected.
/// </summary>
public class PlayerAgent : Agent
{
    /// <summary>Forward speed; the per-action displacement is scaled by Time.fixedDeltaTime.</summary>
    public float moveSpeed = 5f;

    /// <summary>Turn rate in degrees; the per-action rotation is scaled by Time.fixedDeltaTime.</summary>
    public float turnSpeed = 180f;

    /// <summary>
    /// When true, every received action is written to the console. Off by
    /// default because logging on every step floods the console and is slow.
    /// </summary>
    public bool logActions = false;

    private PlayerArea playerArea;
    private PlayerAcademy playerAcademy;
    private new Rigidbody rigidbody;

    /// <summary>Caches the parent area, the academy and this agent's Rigidbody.</summary>
    public override void InitializeAgent()
    {
        base.InitializeAgent();
        playerArea = GetComponentInParent<PlayerArea>();
        playerAcademy = FindObjectOfType<PlayerAcademy>();
        rigidbody = GetComponent<Rigidbody>();
    }

    /// <summary>
    /// Applies one action step.
    /// </summary>
    /// <param name="vectorAction">
    /// [0]: forward amount, multiplied into the move distance (the heuristic emits 0 or 1).
    /// [1]: turn selector; 1 = turn left, 2 = turn right, anything else = no turn.
    /// </param>
    public override void AgentAction(float[] vectorAction)
    {
        if (logActions)
        {
            Debug.Log(vectorAction[0] + "\t" + vectorAction[1]);
        }

        float forwardAmount = vectorAction[0];

        // Discrete branch values arrive as whole-number floats; dispatch on the
        // integer value instead of comparing floats with ==.
        float turnAmount;
        switch ((int)vectorAction[1])
        {
            case 1:
                turnAmount = -1f;
                break;
            case 2:
                turnAmount = 1f;
                break;
            default:
                turnAmount = 0f;
                break;
        }

        rigidbody.MovePosition(transform.position + transform.forward * forwardAmount * moveSpeed * Time.fixedDeltaTime);

        // Yaw around the agent's own up axis. The previous transform.Rotate(transform.up * ...)
        // fed a world-space direction in as local-space Euler angles, which
        // picks the wrong axes whenever the agent is tilted.
        transform.Rotate(0f, turnAmount * turnSpeed * Time.fixedDeltaTime, 0f);

        // Small per-step penalty to encourage fast collection. A maxStep of 0
        // means "no step limit"; dividing by it would add -Infinity reward.
        if (agentParameters.maxStep > 0)
        {
            AddReward(-1f / agentParameters.maxStep);
        }
    }

    /// <summary>
    /// Keyboard control: W moves forward; A yields turn selector 1 (left) and
    /// D yields 2 (right), with A taking precedence when both are held.
    /// </summary>
    /// <returns>{ forwardAction, turnAction } in the layout AgentAction expects.</returns>
    public override float[] Heuristic()
    {
        float forwardAction = 0f;
        float turnAction = 0f;
        if (Input.GetKey(KeyCode.W))
        {
            forwardAction = 1f;
        }
        if (Input.GetKey(KeyCode.A))
        {
            turnAction = 1f;
        }
        else if (Input.GetKey(KeyCode.D))
        {
            turnAction = 2f;
        }
        return new float[] { forwardAction, turnAction };
    }

    /// <summary>Resets the surrounding area (coins and player placement) for a new episode.</summary>
    public override void AgentReset()
    {
        playerArea.ResetArea();
    }

    /// <summary>
    /// Observations: the number of coins remaining in the area (1 value),
    /// followed by the agent's forward vector (3 values).
    /// </summary>
    public override void CollectObservations()
    {
        AddVectorObs(playerArea.coinList.Count);
        AddVectorObs(transform.forward);
    }

    // Any collider tagged "coin" that the agent touches is collected.
    private void OnCollisionEnter(Collision collision)
    {
        if (collision.transform.CompareTag("coin"))
        {
            CollectCoin(collision.gameObject);
        }
    }

    // Removes the coin from the area, rewards +1, and ends the episode once no coins remain.
    private void CollectCoin(GameObject obj)
    {
        playerArea.RemoveSpecificCoin(obj);
        AddReward(1f);
        if (playerArea.CoinsRemaining <= 0)
        {
            Done();
        }
    }
}
using System.Collections.Generic;
using UnityEngine;
using MLAgents;
using TMPro;
/// <summary>
/// Training area that owns one <see cref="PlayerAgent"/> and the coins it
/// collects. Handles coin spawning/removal and player placement.
/// </summary>
public class PlayerArea : Area
{
    /// <summary>The agent that belongs to this area.</summary>
    public PlayerAgent playerAgent;

    /// <summary>Optional label that displays the agent's cumulative reward.</summary>
    public TextMeshPro rewardText;

    /// <summary>Prefab cloned for every spawned coin.</summary>
    public Coin coinPrefab;

    /// <summary>Number of coins spawned on each reset (previously hard-coded to 5).</summary>
    public int coinCount = 5;

    private PlayerAcademy playerAcademy;

    /// <summary>Coins currently alive in this area.</summary>
    public List<GameObject> coinList;

    /// <summary>Destroys all coins, repositions the player and spawns a fresh set of coins.</summary>
    public override void ResetArea()
    {
        RemoveAllCoins();
        PlacePlayer();
        SpawnCoins(coinCount);
    }

    /// <summary>Removes <paramref name="coinObj"/> from <see cref="coinList"/> and destroys it.</summary>
    public void RemoveSpecificCoin(GameObject coinObj)
    {
        coinList.Remove(coinObj);
        Destroy(coinObj);
    }

    /// <summary>Number of coins still tracked in <see cref="coinList"/>.</summary>
    public int CoinsRemaining
    {
        get { return coinList.Count; }
    }

    /// <summary>
    /// Picks a random point relative to <paramref name="center"/>.
    /// </summary>
    /// <param name="c">
    /// Region selector: 0 = the player spawn region (used by PlacePlayer),
    /// 1 = the coin spawn region (used by SpawnCoins), anything else = exactly <paramref name="center"/>.
    /// </param>
    /// <param name="center">Reference point, typically the area's position.</param>
    public static Vector3 ChooseRandomPosition(int c, Vector3 center)
    {
        if (c == 0)
        {
            return center + new Vector3(UnityEngine.Random.Range(1f, 4f), 2f, UnityEngine.Random.Range(-1f, 1f));
        }
        else if (c == 1)
        {
            return center + new Vector3(UnityEngine.Random.Range(-4.5f, 4.5f), 1.85f, UnityEngine.Random.Range(-2.5f, 2.5f));
        }
        else
        {
            return center;
        }
    }

    // Destroys every tracked coin (skipping already-destroyed entries) and starts a new empty list.
    private void RemoveAllCoins()
    {
        if (coinList != null)
        {
            for (int i = 0; i < coinList.Count; i++)
            {
                if (coinList[i] != null)
                {
                    Destroy(coinList[i]);
                }
            }
        }
        coinList = new List<GameObject>();
    }

    // Stops the player's motion and teleports it to a random point in the player spawn region.
    private void PlacePlayer()
    {
        Rigidbody rigidbody = playerAgent.GetComponent<Rigidbody>();
        rigidbody.velocity = Vector3.zero;
        rigidbody.angularVelocity = Vector3.zero;
        playerAgent.transform.position = ChooseRandomPosition(0, transform.position) + Vector3.up * .5f;
        // playerAgent.transform.rotation = Quaternion.Euler(0f, UnityEngine.Random.Range(0f, 360f), 0f);
    }

    // Instantiates c coins at random points in the coin spawn region, parented to this area.
    private void SpawnCoins(int c)
    {
        for (int i = 0; i < c; i++)
        {
            GameObject coinObj = Instantiate<GameObject>(coinPrefab.gameObject);
            coinObj.transform.position = ChooseRandomPosition(1, transform.position) + Vector3.up * .5f;
            coinObj.transform.SetParent(transform);
            coinList.Add(coinObj);
        }
    }

    private void Start()
    {
        playerAcademy = FindObjectOfType<PlayerAcademy>();
        ResetArea();
    }

    // Refreshes the reward label each frame.
    private void Update()
    {
        // rewardText is optional UI; without this guard an unassigned reference
        // throws a NullReferenceException every frame. Use != null (not ?.) because
        // UnityEngine.Object overloads == for unassigned/destroyed objects.
        if (rewardText != null && playerAgent != null)
        {
            rewardText.text = playerAgent.GetCumulativeReward().ToString("0.00");
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment