Skip to content

Instantly share code, notes, and snippets.

@ArztSamuel
Last active May 19, 2022 08:43
Show Gist options
  • Save ArztSamuel/499e617844ca4ce6e222183bd23752f0 to your computer and use it in GitHub Desktop.
Save ArztSamuel/499e617844ca4ce6e222183bd23752f0 to your computer and use it in GitHub Desktop.
The most important parts of the Agent code used in the project shown at https://youtu.be/VMp6pq6_QjI
public class ParkingCarAgent : Agent
{
[SerializeField]
private Transform TargetParkingSpot;
[SerializeField]
// = Reward every 'interval' units getting closer
private float DistanceRewardInterval = 3f;
// Thresholds defining when the task is complete
[SerializeField]
private float DistanceThreshold = 2;
[SerializeField]
private float RotationThreshold = 20;
[SerializeField]
private float SpeedTheshold = 5f;
// Bounds the agent may not leave
[SerializeField]
private Bounds AllowedBounds;
private DistanceSensor[] distanceSensors;
...
/// <summary>
/// Fills the observation vector for the current step. Order of the entries:
/// speed, normalized agent x, normalized agent z, normalized agent y rotation,
/// target x offset, target z offset, target y rotation offset (all in normalized
/// units), followed by one normalized distance per sensor.
/// </summary>
public override void CollectObservations()
{
    base.CollectObservations();

    // Normalize agent and target pose up front; these helpers only read state.
    Vector3 agentPos = GetNormalizedPosition(transform.position);
    Vector3 agentRot = GetNormalizedRotation(transform.rotation);
    Vector3 targetPos = GetNormalizedPosition(TargetParkingSpot.position);
    Vector3 targetRot = GetNormalizedRotation(TargetParkingSpot.rotation);

    // Agent velocity, position and y rotation
    AddVectorObs(carPhysics.CurrentSpeed);
    AddVectorObs(agentPos.x);
    AddVectorObs(agentPos.z);
    AddVectorObs(agentRot.y);

    // Target position / y rotation, expressed relative to the agent.
    // Note: the rotation offset is a plain difference of normalized angles (no wrap-around handling).
    AddVectorObs(targetPos.x - agentPos.x);
    AddVectorObs(targetPos.z - agentPos.z);
    AddVectorObs(targetRot.y - agentRot.y);

    // Refresh every sensor and append its reading
    for (int i = 0; i < distanceSensors.Length; i++)
    {
        DistanceSensor current = distanceSensors[i];
        current.UpdateSensorReadings();
        AddVectorObs(current.NormalizedDistance);
    }
}
/// <summary>
/// Applies the received action to the car physics and hands out this step's rewards:
/// a small reward/penalty for crossing distance bands towards/away from the target,
/// a completion reward (ending the episode) when the car is close and slow enough,
/// and a penalty (ending the episode) when the car is outside the allowed bounds.
/// </summary>
/// <param name="vectorAction">
/// Continuous action values. Only elements 0 and 1 are read by this method:
/// [0] is a signed pedal axis (positive part = throttle, negated negative part = braking),
/// [1] is passed on unchanged as the turning input.
/// </param>
/// <param name="textAction">Forwarded to the base implementation; not used otherwise.</param>
public override void AgentAction(float[] vectorAction, string textAction)
{
    base.AgentAction(vectorAction, textAction);
    if (IsDone())
    {
        return;
    }

    // One signed axis drives both pedals; each pedal only sees its own half of the axis.
    float pedalAxis = vectorAction[0];
    carPhysics.CurrentThrottle = Mathf.Max(0f, pedalAxis);
    carPhysics.CurrentBraking = Mathf.Max(0f, -pedalAxis);
    carPhysics.CurrentTurning = vectorAction[1];

    // Reward for getting closer; Note: could use sqrDistance here for performance
    Vector3 agentPosition = transform.position;
    float distanceToTarget = Vector3.Distance(agentPosition, TargetParkingSpot.position);
    bool movedCloser = distanceToTarget < previousDistance;
    if (movedCloser)
    {
        // Reward only when a lower distance band (multiple of the interval) has been entered.
        int rewardBandNow = (int)(distanceToTarget / DistanceRewardInterval);
        int rewardBandBefore = (int)(previousDistance / DistanceRewardInterval);
        if (rewardBandNow < rewardBandBefore)
        {
            AddReward(0.02f);
        }
        previousDistance = distanceToTarget;
    }
    else
    {
        // Note: '* 2' is a hard coded value here, introduced after tuning so that the penalty occurs
        // less frequently than the reward, in order to not 'scare' the AI of performing corrective
        // maneuvers where it first has to increase the distance to the target parking spot.
        float penaltyInterval = DistanceRewardInterval * 2;
        int penaltyBandNow = (int)(distanceToTarget / penaltyInterval);
        int penaltyBandBefore = (int)(previousDistance / penaltyInterval);
        if (penaltyBandNow > penaltyBandBefore)
        {
            if (Verbose)
            {
                Debug.Log("Distance based penalty");
            }
            AddReward(-0.04f);
            // The reference distance is only moved outwards once a penalty has actually been given.
            previousDistance = distanceToTarget;
        }
    }

    // Check task completion: close enough to the target and slow enough.
    if (distanceToTarget <= DistanceThreshold && Mathf.Abs(carPhysics.CurrentSpeed) <= SpeedTheshold)
    {
        // Angle wrap-around: a difference above 90 is folded back, so facing the opposite way scores the same.
        float rawAngle = Quaternion.Angle(transform.rotation, TargetParkingSpot.rotation);
        float misalignment = rawAngle > 90 ? 180 - rawAngle : rawAngle;

        // Determine how well (= how parallel) the AI parked: full reward up to the
        // rotation threshold, scaled down towards 0 as the misalignment approaches 90.
        float parkingReward = misalignment > RotationThreshold
            ? 1 - GetNormalizedValue(misalignment, RotationThreshold, 90)
            : 1f;
        AddReward(parkingReward);
        Done();
        return;
    }

    // Leaving the allowed area ends the episode with a penalty.
    // NOTE(review): the position is truncated to integers before the check, which tolerates
    // positions slightly outside AllowedBounds - confirm whether that is intended.
    Vector3Int truncatedPosition = new Vector3Int((int)agentPosition.x, (int)agentPosition.y, (int)agentPosition.z);
    if (!AllowedBounds.Contains(truncatedPosition))
    {
        AddReward(-1.0f);
        Done();
    }
}
/// <summary>
/// Expresses a world position relative to <c>AllowedBounds</c>: each axis is passed
/// through <c>GetNormalizedValue</c> with the bounds' min and max on that axis as range.
/// </summary>
/// <param name="position">World position to normalize.</param>
/// <returns>Vector holding the normalized x, y and z components.</returns>
private Vector3 GetNormalizedPosition(in Vector3 position)
{
    // Read both corners once instead of once per axis
    Vector3 lowerCorner = AllowedBounds.min;
    Vector3 upperCorner = AllowedBounds.max;

    return new Vector3(
        GetNormalizedValue(position.x, lowerCorner.x, upperCorner.x),
        GetNormalizedValue(position.y, lowerCorner.y, upperCorner.y),
        GetNormalizedValue(position.z, lowerCorner.z, upperCorner.z));
}
/// <summary>
/// Converts a rotation to its euler angles and passes each angle through
/// <c>GetNormalizedValue</c> with the range 0 to 360.
/// </summary>
/// <param name="rotation">Rotation to normalize.</param>
/// <returns>Vector holding the normalized x, y and z euler angles.</returns>
private Vector3 GetNormalizedRotation(in Quaternion rotation)
{
    // Convert to euler angles once and reuse the result for all three axes
    Vector3 euler = rotation.eulerAngles;

    return new Vector3(
        GetNormalizedValue(euler.x, 0, 360),
        GetNormalizedValue(euler.y, 0, 360),
        GetNormalizedValue(euler.z, 0, 360));
}
/// <summary>
/// Linearly maps <paramref name="currentValue"/> so that <paramref name="minValue"/>
/// becomes 0 and <paramref name="maxValue"/> becomes 1. The result is not clamped:
/// values outside the range map below 0 or above 1.
/// </summary>
/// <param name="currentValue">Value to normalize.</param>
/// <param name="minValue">Value that maps to 0.</param>
/// <param name="maxValue">Value that maps to 1.</param>
/// <returns>
/// The normalized value, or 0 when <paramref name="minValue"/> equals
/// <paramref name="maxValue"/> (zero-width range).
/// </returns>
private float GetNormalizedValue(float currentValue, float minValue, float maxValue)
{
    float range = maxValue - minValue;

    // A zero-width range (e.g. bounds that are flat on one axis) would otherwise divide
    // by zero and produce NaN or +/-Infinity; report 0 instead of a non-finite value.
    if (range == 0f)
    {
        return 0f;
    }

    return (currentValue - minValue) / range;
}
/// <summary>
/// Collision callback: applies a penalty of -0.12 when the collider that was hit has a
/// <c>Knockable</c> component or has a <c>ParkingCar</c> component on itself or a parent.
/// Other collisions are ignored here.
/// </summary>
/// <param name="collision">Collision data for the contact that started.</param>
void OnCollisionEnter(Collision collision)
{
    GameObject struckObject = collision.collider.gameObject;

    // The parent lookup only runs when no Knockable was found on the object itself
    bool hitPenalizedObject = struckObject.GetComponent<Knockable>() || struckObject.GetComponentInParent<ParkingCar>();
    if (!hitPenalizedObject)
    {
        return;
    }

    AddReward(-0.12f);
}
...
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment