Skip to content

Instantly share code, notes, and snippets.

@tylerlindell
Last active September 26, 2017 18:49
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save tylerlindell/292391f87a0b7c0784aac77fac4fd3b4 to your computer and use it in GitHub Desktop.
Save tylerlindell/292391f87a0b7c0784aac77fac4fd3b4 to your computer and use it in GitHub Desktop.
Updated files from Unity NEAT for the bouncing-cubes experiment
/*
* LocomotionController is responsible for each individual Smart Unit
*/
using System.Collections;
using System.Collections.Generic;
using SharpNeat.Phenomes;
using UnityEngine;
/// <summary>
/// Controls one "smart unit" (cube). When target tracking is enabled it steers the unit toward
/// the Target with a velocity-error controller. On every physics step it feeds the unit's height,
/// thrust flag and running fitness into the attached network, and uses the network's outputs to
/// decide when and how hard to thrust upward. Fitness accumulates as 1 / distance-to-target on
/// every physics step.
/// </summary>
public class LocomotionController : UnitController
{
    private Rigidbody rBody;
    private bool IsRunning;
    private IBlackBox box;
    // Copied in Start() from the TargetController component on the Target GameObject.
    private bool toggleTargetTracking = true;

    public float fitness;
    public GameObject target;
    public float force = 5f;

    // Tuning for the controller that steers the unit toward the target.
    public float toVelocity = 2.5f;
    public float maxVelocity = 15.0f;
    public float maxForce = 40.0f;
    public float gain = 5f;

    /// <summary>
    /// Caches the Rigidbody, locates the "Target" object and reads its tracking flag.
    /// </summary>
    void Start()
    {
        rBody = GetComponent<Rigidbody>();
        // Looked up every time a new unit is instantiated.
        target = GameObject.Find("Target");
        if (target != null)
        {
            // A Target without a TargetController keeps the default (tracking on)
            // instead of throwing a NullReferenceException.
            TargetController targetController = target.GetComponent<TargetController>();
            if (targetController != null)
            {
                toggleTargetTracking = targetController.isTracked;
            }
        }
    }

    /// <summary>
    /// Physics-step control loop. Runs only while Activate() has been called and Stop() has not,
    /// and only when a target was found (its position is needed for steering and fitness).
    /// </summary>
    void FixedUpdate()
    {
        if (!IsRunning || target == null)
        {
            return;
        }

        Vector3 heading = target.transform.position - transform.position;
        float distance = heading.magnitude;

        if (toggleTargetTracking)
        {
            // The desired velocity is proportional to the offset to the target (clamped to maxVelocity).
            // The applied force is proportional to the velocity error (clamped to maxForce).
            Vector3 tgtVelocity = Vector3.ClampMagnitude(toVelocity * heading, maxVelocity);
            Vector3 error = tgtVelocity - rBody.velocity;
            Vector3 lateralForce = Vector3.ClampMagnitude(gain * error, maxForce);
            rBody.AddForce(lateralForce);
        }

        ISignalArray inputArr = box.InputSignalArray;
        ISignalArray outputArr = box.OutputSignalArray;

        // Input 0: current height, so the network can judge when a thrust should be triggered.
        inputArr[0] = transform.position.y;

        // The outputs are read here, before box.Activate() below. The decision in this step
        // therefore comes from the previous activation.
        // Output 0 acts as a height threshold: thrust while the unit is below it.
        if (transform.position.y < (float)outputArr[0])
        {
            // Input 1: 1 while thrust is being applied.
            inputArr[1] = 1;
            // Output 1 scales the upward impulse.
            rBody.AddForce(Vector3.up * (float)outputArr[1] * force, ForceMode.Impulse);
            // Outputs 2-4 overwrite the angular velocity directly.
            rBody.angularVelocity = new Vector3((float)outputArr[2], (float)outputArr[3], (float)outputArr[4]);
        }
        else
        {
            inputArr[1] = 0;
        }

        // Input 2: running fitness, so the network gets feedback during the trial, not only at the end.
        inputArr[2] = fitness;
        box.Activate();

        // The reward grows as the unit gets closer (1/100 at distance 100, 1 at distance 1).
        // magnitude is never negative, so only distance == 0 needs guarding.
        if (distance > 0)
        {
            AddFitness(1 / distance);
        }
    }

    /// <summary>Attaches the network that drives this unit and starts the control loop.</summary>
    /// <param name="box">Network whose signal arrays FixedUpdate reads and writes.</param>
    public override void Activate(IBlackBox box)
    {
        this.box = box;
        this.IsRunning = true;
    }

    /// <summary>
    /// Returns the fitness accumulated since the previous call, floored at 0, and resets the
    /// accumulator so the next trial starts from zero.
    /// </summary>
    public override float GetFitness()
    {
        float fit = fitness;
        fitness = 0;
        if (fit < 0)
        {
            fit = 0;
        }
        return fit;
    }

    /// <summary>Halts the control loop; FixedUpdate becomes a no-op.</summary>
    public override void Stop()
    {
        this.IsRunning = false;
    }

    // Adds one physics step's reward to the running fitness.
    void AddFitness(float fit)
    {
        fitness += fit;
    }
}
/*
* Optimizer is responsible for the training sessions (starting, stopping, and evaluating them)
*/
using UnityEngine;
using System.Collections;
using SharpNeat.Phenomes;
using System.Collections.Generic;
using SharpNeat.EvolutionAlgorithms;
using SharpNeat.Genomes.Neat;
using System;
using System.Xml;
using System.IO;
/// <summary>
/// Runs NEAT training sessions in the scene. It starts and stops the evolutionary algorithm,
/// spawns a Unit for every network under evaluation, reports unit fitness back to the EA,
/// autosaves the population and champion every 50 generations, and can replay the saved champion.
/// </summary>
public class Optimizer : MonoBehaviour {
    // Network input/output counts. These should match how many signals the unit controller
    // writes and reads.
    const int NUM_INPUTS = 3;
    const int NUM_OUTPUTS = 5;

    public int Trials;//number of trials for each generation
    public float TrialDuration;//how long a trial lasts
    public float StoppingFitness;//at what fitness should we max out at
    bool EARunning;//are we training right now?
    string popFileSavePath, champFileSavePath;
    private bool isSaved = false;//autosave already done for the current 50-generation mark
    Dictionary<IBlackBox, UnitController> ControllerMap = new Dictionary<IBlackBox, UnitController>();
    private DateTime startTime;//start time of training
    private float timeLeft;//time remaining until the next FPS measurement
    private float accum;//sum of per-frame FPS samples since the last measurement
    private int frames;//number of samples in accum
    private float updateInterval = 12;//seconds (Time.deltaTime-based) between FPS measurements
    private uint Generation;
    private double Fitness;
    SimpleExperiment experiment;
    static NeatEvolutionAlgorithm<NeatGenome> _ea;
    public GameObject Target;//the target we want our bobbers chasing
    public GameObject Unit;//the smart object
    public float distaceTargetAllowed = 1;//half-width of the square the target is randomly placed in
    public bool doTrain = true;//toggle the start and stop training buttons
    public bool canThrottleFPS = false;//allow lowering the time scale when FPS drops below fpsMin
    public float fpsMin = 10;//minimum frames per second before throttling kicks in
    public bool showDebugger = false;//toggle debug logger
    public float timeScale;//used for setting or monitoring the current timescale
    public string trainedObjectName = "bob";//name of file to use / create

    /// <summary>
    /// Loads the experiment configuration from Resources and computes the save paths for the
    /// population and champion files.
    /// </summary>
    void Start () {
        accum = 0;
        Utility.DebugLog = showDebugger;
        experiment = new SimpleExperiment();
        XmlDocument xmlConfig = new XmlDocument();
        TextAsset textAsset = (TextAsset)Resources.Load("experiment.config");
        xmlConfig.LoadXml(textAsset.text);
        experiment.SetOptimizer(this);
        experiment.Initialize(trainedObjectName + " Experiment", xmlConfig.DocumentElement, NUM_INPUTS, NUM_OUTPUTS);
        champFileSavePath = Application.dataPath + string.Format("/{0}/{1}.champ.xml", "AIAgents", trainedObjectName);
        popFileSavePath = Application.dataPath + string.Format("/{0}/{1}.pop.xml", "AIAgents", trainedObjectName);
        if (Utility.DebugLog)
        {
            Utility.Log(champFileSavePath);
        }
    }

    /// <summary>
    /// Per frame: applies the inspector time scale, optionally throttles it to keep FPS up,
    /// and autosaves once per 50-generation mark while training.
    /// </summary>
    void Update()
    {
        Utility.DebugLog = showDebugger;
        if (Time.timeScale != timeScale)
        {
            Time.timeScale = timeScale;
        }
        if (canThrottleFPS && EARunning)
        {
            ThrottleTimeScale();
        }
        // Re-arm the autosave one generation after a save, so the frames that run while the
        // generation number is unchanged do not save repeatedly.
        if(isSaved && ((Generation - 1) % 50 == 0))
        {
            isSaved = false;
        }
        //autosave
        if (EARunning && (Generation % 50 == 0) && !isSaved)
        {
            Save();
            isSaved = true;
        }
    }

    // Averages an FPS estimate over updateInterval and lowers the time scale by one step
    // whenever the average falls below fpsMin.
    private void ThrottleTimeScale()
    {
        timeLeft -= Time.deltaTime;
        accum += (Time.timeScale / Time.smoothDeltaTime) * 1.75f; //this provides a much more accurate number with the 1.75 multiplier
        ++frames;
        if (timeLeft <= 0.0)
        {
            var fps = accum / frames;
            timeLeft = updateInterval;
            accum = 0.0f;
            frames = 0;
            if (Utility.DebugLog)
            {
                Utility.Log("Measured FPS: " + fps);
            }
            // Never throttle below 1. Without a floor the repeated decrement reaches 0
            // (simulation, and therefore training, freezes) and then negative values.
            if (fps < fpsMin && timeScale > 1)
            {
                timeScale = Math.Max(1f, timeScale - 1);
                Time.timeScale = timeScale;
                if (Utility.DebugLog)
                {
                    Utility.Log("Lowering time scale to " + Time.timeScale);
                }
            }
        }
    }

    // Moves the Target to a random X/Z point within +/- distaceTargetAllowed, keeping its height.
    private void RandomizeTargetPosition()
    {
        float range = distaceTargetAllowed;
        Target.transform.position = new Vector3(
            UnityEngine.Random.Range(-range, range),
            Target.transform.position.y,
            UnityEngine.Random.Range(-range, range));
    }

    /// <summary>
    /// Creates the evolutionary algorithm from the saved population path and starts it.
    /// Does nothing if a session is already running.
    /// </summary>
    public void StartEA()
    {
        // A second start would replace the static _ea and orphan the running algorithm,
        // leaving its handlers attached where StopEA can no longer reach it.
        if (EARunning)
        {
            return;
        }
        RandomizeTargetPosition();
        if (Utility.DebugLog)
        {
            Utility.Log("Starting PhotoTaxis experiment");
        }
        _ea = experiment.CreateEvolutionAlgorithm(popFileSavePath);
        startTime = DateTime.Now;
        _ea.UpdateEvent += new EventHandler(ea_UpdateEvent);
        _ea.PausedEvent += new EventHandler(ea_PauseEvent);
        Generation = _ea.CurrentGeneration;
        Time.timeScale = timeScale;
        _ea.StartContinue();
        EARunning = true;
    }

    // EA update handler: moves the target and records the current generation and best fitness.
    void ea_UpdateEvent(object sender, EventArgs e)
    {
        RandomizeTargetPosition();
        if (Utility.DebugLog)
        {
            Utility.Log(string.Format("gen={0:N0} bestFitness={1:N6}",
                _ea.CurrentGeneration, _ea.Statistics._maxFitness));
        }
        Fitness = _ea.Statistics._maxFitness;
        Generation = _ea.CurrentGeneration;
    }

    // EA paused handler: resets the time scale, saves, and marks training as finished.
    void ea_PauseEvent(object sender, EventArgs e)
    {
        // Cleared first so a failing Save() cannot leave the session flagged as running,
        // which would block StartEA.
        EARunning = false;
        timeScale = 1;
        if (Utility.DebugLog)
        {
            Utility.Log("Done ea'ing (and neat'ing)");
        }
        // The original opened a StreamReader on popFileSavePath here and never disposed it.
        // That leaked the file handle and could lock the file against later saves.
        Save();
        DateTime endTime = DateTime.Now;
        if (Utility.DebugLog)
        {
            Utility.Log("Total time elapsed: " + (endTime - startTime));
        }
    }

    /// <summary>Requests the running EA to stop; the pause event finishes the shutdown.</summary>
    public void StopEA()
    {
        if (_ea != null && _ea.RunState == SharpNeat.Core.RunState.Running)
        {
            _ea.Stop();
        }
    }

    /// <summary>
    /// Writes the whole population and the current champion to their XML files, creating the
    /// save folder if needed.
    /// </summary>
    private void Save()
    {
        XmlWriterSettings _xwSettings = new XmlWriterSettings();
        _xwSettings.Indent = true;
        // Both files are written to <dataPath>/AIAgents. Create that folder itself (the original
        // only checked dataPath). Otherwise XmlWriter.Create throws DirectoryNotFoundException
        // on the first save.
        string saveDir = Path.GetDirectoryName(popFileSavePath);
        if (!Directory.Exists(saveDir))
        {
            if (Utility.DebugLog)
            {
                Debug.Log("Creating subdirectory");
            }
            Directory.CreateDirectory(saveDir);
        }
        using (XmlWriter xw = XmlWriter.Create(popFileSavePath, _xwSettings))
        {
            experiment.SavePopulation(xw, _ea.GenomeList);
        }
        // Also save the best genome
        using (XmlWriter xw = XmlWriter.Create(champFileSavePath, _xwSettings))
        {
            experiment.SavePopulation(xw, new NeatGenome[] { _ea.CurrentChampGenome });
        }
    }

    /// <summary>Spawns a Unit driven by <paramref name="box"/> and registers it for fitness lookup.</summary>
    public void Evaluate(IBlackBox box)
    {
        GameObject obj = Instantiate(Unit, Unit.transform.position, Unit.transform.rotation) as GameObject;
        UnitController controller = obj.GetComponent<UnitController>();
        ControllerMap.Add(box, controller);
        controller.Activate(box);
    }

    /// <summary>Destroys the Unit that was spawned for <paramref name="box"/>.</summary>
    public void StopEvaluation(IBlackBox box)
    {
        // NOTE(review): the entry stays in ControllerMap, presumably because GetFitness(box) is
        // called after this. Nothing ever removes entries, so the map grows every generation.
        // Confirm the caller's order before pruning here.
        UnitController ct = ControllerMap[box];
        Destroy(ct.gameObject);
    }

    /// <summary>Loads the saved champion genome and spawns a Unit driven by it.</summary>
    public void RunBest()
    {
        timeScale = 1;
        NeatGenome genome = null;
        // Try to load the genome from the XML document.
        try
        {
            using (XmlReader xr = XmlReader.Create(champFileSavePath))
            {
                genome = NeatGenomeXmlIO.ReadCompleteGenomeList(xr, false, (NeatGenomeFactory)experiment.CreateGenomeFactory())[0];
            }
        }
        catch (Exception e1)
        {
            Debug.LogError(" Error loading genome from file!\nLoading aborted.\n" + e1.Message + "\nin: " + champFileSavePath);
            return;
        }
        // Decode the genome into a phenome (neural network).
        var genomeDecoder = experiment.CreateGenomeDecoder();
        var phenome = genomeDecoder.Decode(genome);
        GameObject obj = Instantiate(Unit, Unit.transform.position, Unit.transform.rotation) as GameObject;
        UnitController controller = obj.GetComponent<UnitController>();
        ControllerMap.Add(phenome, controller);
        controller.Activate(phenome);
    }

    /// <summary>
    /// Moves the target to a new random spot, then returns the fitness of the Unit driven by
    /// <paramref name="box"/>, or 0 if no Unit was registered for it.
    /// </summary>
    public float GetFitness(IBlackBox box)
    {
        RandomizeTargetPosition();
        UnitController controller;
        if (ControllerMap.TryGetValue(box, out controller))
        {
            return controller.GetFitness();
        }
        return 0;
    }

    // Training controls plus a read-only generation/fitness readout.
    void OnGUI()
    {
        if (doTrain)
        {
            if (GUI.Button(new Rect(10, 10, 100, 40), "Start EA"))
            {
                StartEA();
            }
            if (GUI.Button(new Rect(10, 60, 100, 40), "Stop EA"))
            {
                StopEA();
            }
        }
        if (GUI.Button(new Rect(10, 110, 100, 40), "Run best"))
        {
            RunBest();
        }
        GUI.Button(new Rect(10, Screen.height - 70, 125, 60), string.Format("Generation: {0}\nFitness: {1:0.00}", Generation, Fitness));
    }
}
/*
* TargetController is responsible for the Target Game Object
*/
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
/// <summary>
/// Settings component placed on the Target GameObject. LocomotionController reads
/// <see cref="isTracked"/> in its Start() to decide whether units steer toward the target.
/// </summary>
public class TargetController : MonoBehaviour {
    // When true (the default), units apply a steering force toward this target.
    public bool isTracked = true;
}
/*
* TargetFrameRate raises the frame-rate cap ("frame jacking") so training sessions run faster
*/
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
/// <summary>
/// Disables vSync and raises Application.targetFrameRate, so training can run through as many
/// generations as possible as fast as possible.
/// </summary>
public class TargetFrameRate : MonoBehaviour {
    // Frame-rate cap to request; defaults to the previously hard-coded 300.
    public int targetFrameRate = 300;

    void Awake()
    {
        ApplyFrameRateSettings();
    }

    // Re-applied every frame, as before, presumably so the settings stay in force even if
    // something else changes them at runtime.
    void Update () {
        ApplyFrameRateSettings();
    }

    // Single place for the frame-rate settings that Awake and Update previously duplicated.
    private void ApplyFrameRateSettings()
    {
        Application.targetFrameRate = targetFrameRate;
        // targetFrameRate is ignored while vSync is enabled, so turn vSync off.
        QualitySettings.vSyncCount = 0;
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment