import java.awt.Color;
import java.awt.Graphics2D;
import java.awt.geom.Rectangle2D;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import assignment4.BasicGridWorld;
import assignment4.util.AgentPainter;
import assignment4.util.AnalysisAggregator;
import assignment4.util.LocationPainter;
import assignment4.util.MapPrinter;
import burlap.behavior.policy.BoltzmannQPolicy;
import burlap.behavior.policy.Policy;
import burlap.behavior.singleagent.EpisodeAnalysis;
import burlap.behavior.singleagent.auxiliary.StateReachability;
import burlap.behavior.singleagent.auxiliary.valuefunctionvis.ValueFunctionVisualizerGUI;
import burlap.behavior.singleagent.learning.tdmethods.QLearning;
import burlap.behavior.singleagent.planning.stochastic.policyiteration.PolicyIteration;
import burlap.behavior.singleagent.planning.stochastic.valueiteration.ValueIteration;
import burlap.behavior.valuefunction.ValueFunction;
import burlap.domain.singleagent.gridworld.GridWorldDomain;
import burlap.oomdp.core.Domain;
import burlap.oomdp.core.TerminalFunction;
import burlap.oomdp.core.objects.ObjectInstance;
import burlap.oomdp.core.states.State;
import burlap.oomdp.singleagent.GroundedAction;
import burlap.oomdp.singleagent.RewardFunction;
import burlap.oomdp.singleagent.SADomain;
import burlap.oomdp.singleagent.environment.SimulatedEnvironment;
import burlap.oomdp.singleagent.explorer.VisualExplorer;
import burlap.oomdp.statehashing.HashableStateFactory;
import burlap.oomdp.statehashing.SimpleHashableStateFactory;
import burlap.oomdp.visualizer.StateRenderLayer;
import burlap.oomdp.visualizer.StaticPainter;
import burlap.oomdp.visualizer.Visualizer;
public class ReinforcementLearning {

    //Boolean flags that control which parts of the experiment actually run
    private static boolean visualizeInitialGridWorld = true; //Loads a GUI with the agent, walls, and goal

    //runValueIteration, runPolicyIteration, and runQLearning select which algorithms run in the experiment
    private static boolean runValueIteration = true;
    private static boolean runPolicyIteration = true;
    private static boolean runQLearning = true;

    //showValueIterationPolicyMap, showPolicyIterationPolicyMap, and showQLearningPolicyMap each open a GUI
    //for visualizing the corresponding policy map. Consider setting only one to true at a time,
    //since the pop-up window does not indicate which algorithm generated the map.
    private static boolean showValueIterationPolicyMap = true;
    private static boolean showPolicyIterationPolicyMap = true;
    private static boolean showQLearningPolicyMap = true;

    private static Integer MAX_ITERATIONS = 100;
    private static Integer NUM_INTERVALS = 100;
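    //Each algorithm is re-run with iteration limits from MAX_ITERATIONS/NUM_INTERVALS up to
    //MAX_ITERATIONS in equal increments, so results can be compared across iteration counts.
    //In the maps below, 1 marks a wall cell and 0 an open cell.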
    protected static int[][] easyMap = new int[][] {
            { 0, 0, 0, 1, 0 },
            { 0, 1, 0, 0, 0 },
            { 0, 1, 1, 0, 0 },
            { 0, 1, 0, 1, 0 },
            { 0, 0, 0, 0, 0 }, };

    protected static int[][] hardMap = new int[][] {
            { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
            { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
            { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
            { 0, 1, 1, 1, 1, 1, 1, 1, 1, 0 },
            { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
            { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
            { 0, 1, 1, 1, 1, 1, 1, 1, 1, 0 },
            { 0, 0, 0, 1, 0, 0, 0, 1, 0, 0 },
            { 0, 1, 0, 1, 0, 1, 0, 1, 0, 0 },
            { 0, 1, 0, 0, 0, 1, 0, 0, 0, 1 }
    };

    // private static Integer mapLen = map.length-1;
    public static void main(String[] args) {
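        //Each Model defines the terminal (goal) cell and a default step reward of -1,
        //then overrides individual cells: +100 at the goal and -100 at penalty cells.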
        Model easyModel = new Model(new Cell(4, 4), -1);
        easyModel.setReward(new Cell(4, 4), 100);
        easyModel.setReward(new Cell(2, 1), -100);

        Model hardModel = new Model(new Cell(9, 9), -1);
        hardModel.setReward(new Cell(9, 9), 100);
        hardModel.setReward(new Cell(1, 9), -100);
        hardModel.setReward(new Cell(2, 9), -100);
        hardModel.setReward(new Cell(6, 7), -100);
        hardModel.setReward(new Cell(7, 7), -100);
        hardModel.setReward(new Cell(6, 5), -100);
        hardModel.setReward(new Cell(9, 1), -100);

        runAnalysis("Easy", easyMap, easyModel);
        //runAnalysis("Hard", hardMap, hardModel);
    }
    public static void runAnalysis(String type, int[][] mdp, Model model) {
        // convert to BURLAP indexing
        int[][] map = MapPrinter.mapToMatrix(mdp);
        int maxX = map.length - 1;
        int maxY = map[0].length - 1;

        BasicGridWorld gen = new BasicGridWorld(map, maxX, maxY);
        Domain domain = gen.generateDomain();
        State initialState = BasicGridWorld.getExampleState(domain);

        SimulatedEnvironment env = new SimulatedEnvironment(domain, model, model,
                initialState);
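        //The Model instance is passed twice because it implements both RewardFunction and
        //TerminalFunction for the simulated environment.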
        //Print the map that is being analyzed
        System.out.println("/////" + type + " Grid World Analysis/////\n");
        MapPrinter.printMap(MapPrinter.matrixToMap(map));

        if (visualizeInitialGridWorld) {
            visualizeInitialGridWorld(domain, gen, env, model);
        }

        CustomAnalysisRunner runner = new CustomAnalysisRunner();

        if (runValueIteration) {
            runner.runValueIteration(gen, domain, initialState, model, model, showValueIterationPolicyMap, 0.99);
        }
        if (runPolicyIteration) {
            runner.runPolicyIteration(gen, domain, initialState, model, model, showPolicyIterationPolicyMap, 0.99);
        }
        if (runQLearning) {
            runner.runQLearning(gen, domain, initialState, model, model, env, showQLearningPolicyMap, 0.99);
        }

        AnalysisAggregator.printAggregateAnalysis();
    }
    private static void visualizeInitialGridWorld(Domain domain,
            BasicGridWorld gen, SimulatedEnvironment env, Model model) {
        StateRenderLayer rl = new StateRenderLayer();
        rl.addStaticPainter(new CustomWallPainter(gen.getMap(), model));
        rl.addObjectClassPainter("location", new LocationPainter(gen.getMap()));
        rl.addObjectClassPainter("agent", new AgentPainter(gen.getMap()));
        Visualizer v = new Visualizer(rl);

        VisualExplorer exp = new VisualExplorer(domain, env, v);
        exp.addKeyAction("w", BasicGridWorld.ACTIONNORTH);
        exp.addKeyAction("s", BasicGridWorld.ACTIONSOUTH);
        exp.addKeyAction("d", BasicGridWorld.ACTIONEAST);
        exp.addKeyAction("a", BasicGridWorld.ACTIONWEST);
        exp.setTitle("Easy Grid World");
        exp.initGUI();
    }
    private static class Cell {

        public int x;
        public int y;

        public Cell(int i1, int i2) {
            x = i1;
            y = i2;
        }

        @Override
        public boolean equals(Object obj) {
            return obj instanceof Cell && ((Cell)obj).x == x && ((Cell)obj).y == y;
        }

        @Override
        public int hashCode() {
            int code = 1;
            code = 31 * code + x;
            code = 31 * code + y;
            return code;
        }
    }
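    //Paints wall cells black and cells with custom penalty rewards red on the grid-world canvas.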
    public static class CustomWallPainter implements StaticPainter {

        private int[][] map;
        private Model model;

        public CustomWallPainter(int[][] map, Model model) {
            this.map = map;
            this.model = model;
        }

        @Override
        public void paint(Graphics2D g2, State s, float cWidth, float cHeight) {

            //set up floats for the width and height of our domain
            float fWidth = this.map.length;
            float fHeight = this.map[0].length;

            //determine the width and height of a single cell
            //on our canvas such that the whole map can be painted
            float width = cWidth / fWidth;
            float height = cHeight / fHeight;

            //pass through each cell of our map and, if it's a wall or a penalty cell, paint a
            //black or red rectangle of dimension width x height on our canvas
            for(int i = 0; i < this.map.length; i++){
                for(int j = 0; j < this.map[0].length; j++){

                    boolean red = model.isRed(new Cell(i, j));

                    //is there a wall here?
                    if(this.map[i][j] == 1 || red) {

                        //left coordinate of cell on our canvas
                        float rx = i*width;

                        //top coordinate of cell on our canvas
                        //coordinate system adjustment because the java canvas
                        //origin is in the top left instead of the bottom left
                        float ry = cHeight - height - j*height;

                        if(red) {
                            g2.setColor(Color.RED);
                        } else {
                            g2.setColor(Color.BLACK);
                        }
                        g2.fill(new Rectangle2D.Float(rx, ry, width, height));
                    }
                }
            }
        }
    }
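    //Acts as both the reward function and terminal function: a fixed default step reward,
    //per-cell reward overrides, and a single terminal (goal) cell.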
    private static class Model implements TerminalFunction, RewardFunction {

        Cell terminalCell;
        double defaultReward;

        Map<Cell, Double> customRewards = new HashMap<>();

        public Model(Cell terminal, double def) {
            terminalCell = terminal;
            defaultReward = def;
        }

        public void setReward(Cell cell, double reward) {
            customRewards.put(cell, reward);
        }

        //a cell is drawn red if it has a custom (penalty) reward and is not the terminal cell
        public boolean isRed(Cell cell) {
            return customRewards.get(cell) != null && !cell.equals(terminalCell);
        }

        @Override
        public double reward(State s, GroundedAction a, State sprime) {
            // get location of agent in next state
            ObjectInstance agent = sprime.getFirstObjectOfClass(BasicGridWorld.CLASSAGENT);
            int ax = agent.getIntValForAttribute(BasicGridWorld.ATTX);
            int ay = agent.getIntValForAttribute(BasicGridWorld.ATTY);

            Cell cell = new Cell(ax, ay);

            if(customRewards.containsKey(cell)) {
                return customRewards.get(cell);
            }
            return defaultReward;
        }

        @Override
        public boolean isTerminal(State s) {
            // get location of agent in the given state
            ObjectInstance agent = s.getFirstObjectOfClass(BasicGridWorld.CLASSAGENT);
            int ax = agent.getIntValForAttribute(BasicGridWorld.ATTX);
            int ay = agent.getIntValForAttribute(BasicGridWorld.ATTY);

            Cell cell = new Cell(ax, ay);

            return cell.equals(terminalCell);
        }
    }
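    //Runs each algorithm across the iteration sweep and records wall-clock time, episode reward,
    //and steps to the goal through AnalysisAggregator.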
    public static class CustomAnalysisRunner {

        final SimpleHashableStateFactory hashingFactory = new SimpleHashableStateFactory();

        public CustomAnalysisRunner(){
            int increment = MAX_ITERATIONS/NUM_INTERVALS;
            for(int numIterations = increment; numIterations <= MAX_ITERATIONS; numIterations += increment){
                AnalysisAggregator.addNumberOfIterations(numIterations);
            }
        }
        public void runValueIteration(BasicGridWorld gen, Domain domain,
                State initialState, RewardFunction rf, TerminalFunction tf, boolean showPolicyMap, double discount) {
            System.out.println("//Value Iteration Analysis//");
            ValueIteration vi = null;
            Policy p = null;
            EpisodeAnalysis ea = null;
            int increment = MAX_ITERATIONS/NUM_INTERVALS;
            for(int numIterations = increment; numIterations <= MAX_ITERATIONS; numIterations += increment){
                long startTime = System.nanoTime();
                vi = new ValueIteration(
                        domain,
                        rf,
                        tf,
                        discount,
                        hashingFactory,
                        -1, numIterations); //A negative delta threshold guarantees that value iteration never stops
                                            //early on convergence and always runs the full number of iterations,
                                            //for comparison with the other algorithms.

                // run planning from our initial state
                p = vi.planFromState(initialState);
                AnalysisAggregator.addMillisecondsToFinishValueIteration((int) ((System.nanoTime() - startTime) / 1000000));

                // evaluate the policy with one rollout and record the trajectory
                ea = p.evaluateBehavior(initialState, rf, tf);
                AnalysisAggregator.addValueIterationReward(calcRewardInEpisode(ea));
                AnalysisAggregator.addStepsToFinishValueIteration(ea.numTimeSteps());
            }

            // Visualizer v = gen.getVisualizer();
            // new EpisodeSequenceVisualizer(v, domain, Arrays.asList(ea));
            AnalysisAggregator.printValueIterationResults();
            MapPrinter.printPolicyMap(vi.getAllStates(), p, gen.getMap());
            System.out.println("\n\n");

            //visualize the value function and policy.
            if(showPolicyMap){
                simpleValueFunctionVis((ValueFunction)vi, p, initialState, domain, hashingFactory, "Value Iteration");
            }
        }
        public void runPolicyIteration(BasicGridWorld gen, Domain domain,
                State initialState, RewardFunction rf, TerminalFunction tf, boolean showPolicyMap, double discount) {
            System.out.println("//Policy Iteration Analysis//");
            PolicyIteration pi = null;
            Policy p = null;
            EpisodeAnalysis ea = null;
            int increment = MAX_ITERATIONS/NUM_INTERVALS;
            for(int numIterations = increment; numIterations <= MAX_ITERATIONS; numIterations += increment){
                long startTime = System.nanoTime();
                pi = new PolicyIteration(
                        domain,
                        rf,
                        tf,
                        discount,
                        hashingFactory,
                        -1, 1, numIterations); //As with value iteration, a negative delta threshold prevents early
                                               //termination so the iteration counts remain comparable.

                // run planning from our initial state
                p = pi.planFromState(initialState);
                AnalysisAggregator.addMillisecondsToFinishPolicyIteration((int) ((System.nanoTime() - startTime) / 1000000));

                // evaluate the policy with one rollout and record the trajectory
                ea = p.evaluateBehavior(initialState, rf, tf);
                AnalysisAggregator.addPolicyIterationReward(calcRewardInEpisode(ea));
                AnalysisAggregator.addStepsToFinishPolicyIteration(ea.numTimeSteps());
            }

            // Visualizer v = gen.getVisualizer();
            // new EpisodeSequenceVisualizer(v, domain, Arrays.asList(ea));
            AnalysisAggregator.printPolicyIterationResults();
            MapPrinter.printPolicyMap(getAllStates(domain, rf, tf, initialState), p, gen.getMap());
            System.out.println("\n\n");

            //visualize the value function and policy.
            if(showPolicyMap){
                simpleValueFunctionVis(pi, p, initialState, domain, hashingFactory, "Policy Iteration");
            }
        }
        public void simpleValueFunctionVis(ValueFunction valueFunction, Policy p,
                State initialState, Domain domain, HashableStateFactory hashingFactory, String title){

            List<State> allStates = StateReachability.getReachableStates(initialState,
                    (SADomain)domain, hashingFactory);
            ValueFunctionVisualizerGUI gui = GridWorldDomain.getGridWorldValueFunctionVisualization(
                    allStates, valueFunction, p);
            gui.setTitle(title);
            gui.initGUI();
        }
        public void runQLearning(BasicGridWorld gen, Domain domain,
                State initialState, RewardFunction rf, TerminalFunction tf,
                SimulatedEnvironment env, boolean showPolicyMap, double discount) {
            System.out.println("//Q Learning Analysis//");
            QLearning agent = null;
            Policy p = null;
            EpisodeAnalysis ea = null;
            int increment = MAX_ITERATIONS/NUM_INTERVALS;
            for(int numIterations = increment; numIterations <= MAX_ITERATIONS; numIterations += increment){
                long startTime = System.nanoTime();
                agent = new QLearning(
                        domain,
                        discount,
                        hashingFactory,
                        0.99, 0.99); //initial Q-values of 0.99 and a learning rate of 0.99

                //explore with a Boltzmann (softmax) policy over the current Q-values, temperature 0.5
                agent.setLearningPolicy(new BoltzmannQPolicy(agent, 0.5));

                //run numIterations learning episodes, resetting the environment after each one
                for (int i = 0; i < numIterations; i++) {
                    ea = agent.runLearningEpisode(env);
                    env.resetEnvironment();
                }

                agent.initializeForPlanning(rf, tf, 1);
                p = agent.planFromState(initialState);
                AnalysisAggregator.addQLearningReward(calcRewardInEpisode(ea));
                AnalysisAggregator.addMillisecondsToFinishQLearning((int) ((System.nanoTime() - startTime) / 1000000));
                AnalysisAggregator.addStepsToFinishQLearning(ea.numTimeSteps());
            }

            AnalysisAggregator.printQLearningResults();
            MapPrinter.printPolicyMap(getAllStates(domain, rf, tf, initialState), p, gen.getMap());
            System.out.println("\n\n");

            //visualize the value function and policy.
            if(showPolicyMap){
                simpleValueFunctionVis((ValueFunction)agent, p, initialState, domain, hashingFactory, "Q-Learning");
            }
        }
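        //Enumerates the reachable state space by running a short value iteration pass;
        //used only to supply the states needed for printing policy maps.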
        private static List<State> getAllStates(Domain domain,
                RewardFunction rf, TerminalFunction tf, State initialState){
            ValueIteration vi = new ValueIteration(
                    domain,
                    rf,
                    tf,
                    0.99,
                    new SimpleHashableStateFactory(),
                    .5, 100);
            vi.planFromState(initialState);

            return vi.getAllStates();
        }
        public double calcRewardInEpisode(EpisodeAnalysis ea) {
            double myRewards = 0;

            //sum all rewards
            for (int i = 0; i < ea.rewardSequence.size(); i++) {
                myRewards += ea.rewardSequence.get(i);
            }
            return myRewards;
        }
    }
}