import java.awt.Color;
import java.awt.Graphics2D;
import java.awt.geom.Rectangle2D;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import assignment4.BasicGridWorld;
import assignment4.util.AgentPainter;
import assignment4.util.AnalysisAggregator;
import assignment4.util.LocationPainter;
import assignment4.util.MapPrinter;
import burlap.behavior.policy.BoltzmannQPolicy;
import burlap.behavior.policy.Policy;
import burlap.behavior.singleagent.EpisodeAnalysis;
import burlap.behavior.singleagent.auxiliary.StateReachability;
import burlap.behavior.singleagent.auxiliary.valuefunctionvis.ValueFunctionVisualizerGUI;
import burlap.behavior.singleagent.learning.tdmethods.QLearning;
import burlap.behavior.singleagent.planning.stochastic.policyiteration.PolicyIteration;
import burlap.behavior.singleagent.planning.stochastic.valueiteration.ValueIteration;
import burlap.behavior.valuefunction.ValueFunction;
import burlap.domain.singleagent.gridworld.GridWorldDomain;
import burlap.oomdp.core.Domain;
import burlap.oomdp.core.TerminalFunction;
import burlap.oomdp.core.objects.ObjectInstance;
import burlap.oomdp.core.states.State;
import burlap.oomdp.singleagent.GroundedAction;
import burlap.oomdp.singleagent.RewardFunction;
import burlap.oomdp.singleagent.SADomain;
import burlap.oomdp.singleagent.environment.SimulatedEnvironment;
import burlap.oomdp.singleagent.explorer.VisualExplorer;
import burlap.oomdp.statehashing.HashableStateFactory;
import burlap.oomdp.statehashing.SimpleHashableStateFactory;
import burlap.oomdp.visualizer.StateRenderLayer;
import burlap.oomdp.visualizer.StaticPainter;
import burlap.oomdp.visualizer.Visualizer;
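
/**
 * Grid-world reinforcement learning experiment for two maps ("Easy" and "Hard"), built on BURLAP
 * and the assignment4 BasicGridWorld scaffolding. For a range of iteration budgets it runs Value
 * Iteration, Policy Iteration, and Q-Learning, records wall-clock time, episode reward, and episode
 * length via AnalysisAggregator, and can visualize the grid world and the resulting policies.
 */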
public class ReinforcementLearning {

    // Boolean flags that control which parts of the experiment are executed.
    private static boolean visualizeInitialGridWorld = true; // Opens a GUI showing the agent, walls, and goal
    // runValueIteration, runPolicyIteration, and runQLearning select which algorithms the experiment runs.
    private static boolean runValueIteration = true;
    private static boolean runPolicyIteration = true;
    private static boolean runQLearning = true;
    // showValueIterationPolicyMap, showPolicyIterationPolicyMap, and showQLearningPolicyMap each open a GUI
    // visualizing the value function and policy produced by the corresponding algorithm. Each window's
    // title indicates which algorithm generated it.
    private static boolean showValueIterationPolicyMap = true;
    private static boolean showPolicyIterationPolicyMap = true;
    private static boolean showQLearningPolicyMap = true;

    // MAX_ITERATIONS is the largest iteration budget tested; NUM_INTERVALS is the number of data points
    // collected up to MAX_ITERATIONS (with 100/100 the algorithms are rerun at every iteration count).
    private static Integer MAX_ITERATIONS = 100;
    private static Integer NUM_INTERVALS = 100;
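
    // Maps are written row-by-row with 0 = open cell and 1 = wall; runAnalysis converts them to
    // BURLAP's (x, y) indexing with MapPrinter.mapToMatrix before generating the domain.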
    protected static int[][] easyMap = new int[][] {
            { 0, 0, 0, 1, 0 },
            { 0, 1, 0, 0, 0 },
            { 0, 1, 1, 0, 0 },
            { 0, 1, 0, 1, 0 },
            { 0, 0, 0, 0, 0 } };

    protected static int[][] hardMap = new int[][] {
            { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
            { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
            { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
            { 0, 1, 1, 1, 1, 1, 1, 1, 1, 0 },
            { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
            { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 },
            { 0, 1, 1, 1, 1, 1, 1, 1, 1, 0 },
            { 0, 0, 0, 1, 0, 0, 0, 1, 0, 0 },
            { 0, 1, 0, 1, 0, 1, 0, 1, 0, 0 },
            { 0, 1, 0, 0, 0, 1, 0, 0, 0, 1 }
    };

    // private static Integer mapLen = map.length-1;
    public static void main(String[] args) {
        // Easy world: terminal/goal cell at (4, 4) worth +100, one -100 penalty cell, and a -1 step cost.
        Model easyModel = new Model(new Cell(4, 4), -1);
        easyModel.setReward(new Cell(4, 4), 100);
        easyModel.setReward(new Cell(2, 1), -100);

        // Hard world: terminal/goal cell at (9, 9) worth +100, several -100 penalty cells, and a -1 step cost.
        Model hardModel = new Model(new Cell(9, 9), -1);
        hardModel.setReward(new Cell(9, 9), 100);
        hardModel.setReward(new Cell(1, 9), -100);
        hardModel.setReward(new Cell(2, 9), -100);
        hardModel.setReward(new Cell(6, 7), -100);
        hardModel.setReward(new Cell(7, 7), -100);
        hardModel.setReward(new Cell(6, 5), -100);
        hardModel.setReward(new Cell(9, 1), -100);

        runAnalysis("Easy", easyMap, easyModel);
        //runAnalysis("Hard", hardMap, hardModel);
    }
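
    /**
     * Converts the human-readable map into BURLAP form, builds the grid-world domain and a simulated
     * environment backed by the given Model, optionally shows the initial grid world, then runs the
     * selected planning/learning algorithms and prints the aggregated timing, reward, and step results.
     */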
    public static void runAnalysis(String type, int[][] mdp, Model model) {
        // convert to BURLAP indexing
        int[][] map = MapPrinter.mapToMatrix(mdp);
        int maxX = map.length - 1;
        int maxY = map[0].length - 1;

        BasicGridWorld gen = new BasicGridWorld(map, maxX, maxY);
        Domain domain = gen.generateDomain();
        State initialState = BasicGridWorld.getExampleState(domain);
        // The Model serves as both the RewardFunction and the TerminalFunction, so it is passed twice.
        SimulatedEnvironment env = new SimulatedEnvironment(domain, model, model, initialState);

        // Print the map that is being analyzed
        System.out.println("/////" + type + " Grid World Analysis/////\n");
        MapPrinter.printMap(MapPrinter.matrixToMap(map));

        if (visualizeInitialGridWorld) {
            visualizeInitialGridWorld(type, domain, gen, env, model);
        }

        // Run each selected algorithm with a discount factor of 0.99 and aggregate the results.
        CustomAnalysisRunner runner = new CustomAnalysisRunner();
        if (runValueIteration) {
            runner.runValueIteration(gen, domain, initialState, model, model, showValueIterationPolicyMap, 0.99);
        }
        if (runPolicyIteration) {
            runner.runPolicyIteration(gen, domain, initialState, model, model, showPolicyIterationPolicyMap, 0.99);
        }
        if (runQLearning) {
            runner.runQLearning(gen, domain, initialState, model, model, env, showQLearningPolicyMap, 0.99);
        }
        AnalysisAggregator.printAggregateAnalysis();
    }
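
    /**
     * Opens an interactive VisualExplorer for the grid world: walls and penalty cells are drawn by
     * CustomWallPainter, while the goal location and the agent are drawn by the assignment's
     * LocationPainter and AgentPainter.
     */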
    private static void visualizeInitialGridWorld(String type, Domain domain,
            BasicGridWorld gen, SimulatedEnvironment env, Model model) {
        StateRenderLayer rl = new StateRenderLayer();
        rl.addStaticPainter(new CustomWallPainter(gen.getMap(), model));
        rl.addObjectClassPainter("location", new LocationPainter(gen.getMap()));
        rl.addObjectClassPainter("agent", new AgentPainter(gen.getMap()));

        Visualizer v = new Visualizer(rl);
        VisualExplorer exp = new VisualExplorer(domain, env, v);

        // WASD keys move the agent around the grid world.
        exp.addKeyAction("w", BasicGridWorld.ACTIONNORTH);
        exp.addKeyAction("s", BasicGridWorld.ACTIONSOUTH);
        exp.addKeyAction("d", BasicGridWorld.ACTIONEAST);
        exp.addKeyAction("a", BasicGridWorld.ACTIONWEST);

        exp.setTitle(type + " Grid World"); // title follows the map being analyzed rather than always "Easy"
        exp.initGUI();
    }
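
    /** Simple (x, y) grid coordinate; equals and hashCode are overridden so Cells can key the reward map. */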
    private static class Cell {

        public int x;
        public int y;

        public Cell(int i1, int i2) {
            x = i1;
            y = i2;
        }

        @Override
        public boolean equals(Object obj) {
            return obj instanceof Cell && ((Cell) obj).x == x && ((Cell) obj).y == y;
        }

        @Override
        public int hashCode() {
            int code = 1;
            code = 31 * code + x;
            code = 31 * code + y;
            return code;
        }
    }
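
    /**
     * Paints static cells of the grid world: walls (map value 1) are drawn in black, and non-terminal
     * cells carrying a custom penalty reward are drawn in red so the pits are visible in the explorer.
     */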
    public static class CustomWallPainter implements StaticPainter {

        private int[][] map;
        private Model model;

        public CustomWallPainter(int[][] map, Model model) {
            this.map = map;
            this.model = model;
        }

        @Override
        public void paint(Graphics2D g2, State s, float cWidth, float cHeight) {
            // set up floats for the width and height of our domain
            float fWidth = this.map.length;
            float fHeight = this.map[0].length;

            // determine the width and height of a single cell on our canvas
            // such that the whole map can be painted
            float width = cWidth / fWidth;
            float height = cHeight / fHeight;

            // pass through each cell of our map; if it is a wall or a penalty cell,
            // paint a width x height rectangle on our canvas
            for (int i = 0; i < this.map.length; i++) {
                for (int j = 0; j < this.map[0].length; j++) {
                    boolean red = model.isRed(new Cell(i, j));
                    // is there a wall or penalty cell here?
                    if (this.map[i][j] == 1 || red) {
                        // left coordinate of cell on our canvas
                        float rx = i * width;
                        // top coordinate of cell on our canvas;
                        // adjusted because the Java canvas origin is in the top left
                        // instead of the bottom left
                        float ry = cHeight - height - j * height;
                        if (red) {
                            g2.setColor(Color.RED);
                        } else {
                            g2.setColor(Color.BLACK);
                        }
                        g2.fill(new Rectangle2D.Float(rx, ry, width, height));
                    }
                }
            }
        }
    }
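
    /**
     * The MDP's reward and termination model: every step yields defaultReward (-1 here) unless the
     * agent's cell has a custom reward, and the episode ends when the agent reaches terminalCell.
     */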
    private static class Model implements TerminalFunction, RewardFunction {

        Cell terminalCell;
        double defaultReward;
        Map<Cell, Double> customRewards = new HashMap<>();

        public Model(Cell terminal, double def) {
            terminalCell = terminal;
            defaultReward = def;
        }

        /** Overrides the default step reward for a specific cell (e.g. +100 goal, -100 pit). */
        public void setReward(Cell cell, double reward) {
            customRewards.put(cell, reward);
        }

        /** A cell is painted red if it carries a custom reward but is not the terminal (goal) cell. */
        public boolean isRed(Cell cell) {
            return customRewards.get(cell) != null && !cell.equals(terminalCell);
        }

        @Override
        public double reward(State s, GroundedAction a, State sprime) {
            // get location of agent in next state
            ObjectInstance agent = sprime.getFirstObjectOfClass(BasicGridWorld.CLASSAGENT);
            int ax = agent.getIntValForAttribute(BasicGridWorld.ATTX);
            int ay = agent.getIntValForAttribute(BasicGridWorld.ATTY);

            Cell cell = new Cell(ax, ay);
            if (customRewards.containsKey(cell)) {
                return customRewards.get(cell);
            }
            return defaultReward;
        }

        @Override
        public boolean isTerminal(State s) {
            // get location of agent in this state
            ObjectInstance agent = s.getFirstObjectOfClass(BasicGridWorld.CLASSAGENT);
            int ax = agent.getIntValForAttribute(BasicGridWorld.ATTX);
            int ay = agent.getIntValForAttribute(BasicGridWorld.ATTY);

            return new Cell(ax, ay).equals(terminalCell);
        }
    }
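
    /**
     * Runs each algorithm over a sweep of iteration budgets and records, for every budget, the elapsed
     * time in milliseconds and the total reward and length of the resulting episode via AnalysisAggregator.
     * Also prints an ASCII policy map and can open a value-function GUI for each algorithm.
     */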
    public static class CustomAnalysisRunner {

        final SimpleHashableStateFactory hashingFactory = new SimpleHashableStateFactory();

        public CustomAnalysisRunner() {
            // Register the iteration counts that will be evaluated (increment, 2*increment, ..., MAX_ITERATIONS).
            int increment = MAX_ITERATIONS / NUM_INTERVALS;
            for (int numIterations = increment; numIterations <= MAX_ITERATIONS; numIterations += increment) {
                AnalysisAggregator.addNumberOfIterations(numIterations);
            }
        }
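
        /**
         * For each registered iteration budget, runs BURLAP's ValueIteration planner from the initial
         * state, records how long planning took, and evaluates the resulting policy with a single rollout.
         */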
        public void runValueIteration(BasicGridWorld gen, Domain domain,
                State initialState, RewardFunction rf, TerminalFunction tf, boolean showPolicyMap, double discount) {
            System.out.println("//Value Iteration Analysis//");
            ValueIteration vi = null;
            Policy p = null;
            EpisodeAnalysis ea = null;
            int increment = MAX_ITERATIONS / NUM_INTERVALS;
            for (int numIterations = increment; numIterations <= MAX_ITERATIONS; numIterations += increment) {
                long startTime = System.nanoTime();
                // The -1 delta threshold can never be reached, so value iteration always runs the full
                // numIterations; this keeps the iteration counts comparable with the other algorithms.
                vi = new ValueIteration(
                        domain,
                        rf,
                        tf,
                        discount,
                        hashingFactory,
                        -1, numIterations);

                // run planning from our initial state
                p = vi.planFromState(initialState);
                // divide before casting so long runs don't overflow the int millisecond count
                AnalysisAggregator.addMillisecondsToFinishValueIteration((int) ((System.nanoTime() - startTime) / 1000000));

                // evaluate the policy with one rollout and record the trajectory
                ea = p.evaluateBehavior(initialState, rf, tf);
                AnalysisAggregator.addValueIterationReward(calcRewardInEpisode(ea));
                AnalysisAggregator.addStepsToFinishValueIteration(ea.numTimeSteps());
            }

            // Visualizer v = gen.getVisualizer();
            // new EpisodeSequenceVisualizer(v, domain, Arrays.asList(ea));
            AnalysisAggregator.printValueIterationResults();
            MapPrinter.printPolicyMap(vi.getAllStates(), p, gen.getMap());
            System.out.println("\n\n");

            // visualize the value function and policy
            if (showPolicyMap) {
                simpleValueFunctionVis((ValueFunction) vi, p, initialState, domain, hashingFactory, "Value Iteration");
            }
        }
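
        /**
         * For each registered iteration budget, runs BURLAP's PolicyIteration planner from the initial
         * state, records how long planning took, and evaluates the resulting policy with a single rollout.
         */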
        public void runPolicyIteration(BasicGridWorld gen, Domain domain,
                State initialState, RewardFunction rf, TerminalFunction tf, boolean showPolicyMap, double discount) {
            System.out.println("//Policy Iteration Analysis//");
            PolicyIteration pi = null;
            Policy p = null;
            EpisodeAnalysis ea = null;
            int increment = MAX_ITERATIONS / NUM_INTERVALS;
            for (int numIterations = increment; numIterations <= MAX_ITERATIONS; numIterations += increment) {
                long startTime = System.nanoTime();
                pi = new PolicyIteration(
                        domain,
                        rf,
                        tf,
                        discount,
                        hashingFactory,
                        -1, 1, numIterations);

                // run planning from our initial state
                p = pi.planFromState(initialState);
                // divide before casting so long runs don't overflow the int millisecond count
                AnalysisAggregator.addMillisecondsToFinishPolicyIteration((int) ((System.nanoTime() - startTime) / 1000000));

                // evaluate the policy with one rollout and record the trajectory
                ea = p.evaluateBehavior(initialState, rf, tf);
                AnalysisAggregator.addPolicyIterationReward(calcRewardInEpisode(ea));
                AnalysisAggregator.addStepsToFinishPolicyIteration(ea.numTimeSteps());
            }

            // Visualizer v = gen.getVisualizer();
            // new EpisodeSequenceVisualizer(v, domain, Arrays.asList(ea));
            AnalysisAggregator.printPolicyIterationResults();
            MapPrinter.printPolicyMap(getAllStates(domain, rf, tf, initialState), p, gen.getMap());
            System.out.println("\n\n");

            // visualize the value function and policy
            if (showPolicyMap) {
                simpleValueFunctionVis(pi, p, initialState, domain, hashingFactory, "Policy Iteration");
            }
        }
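
        /**
         * Opens BURLAP's grid-world value function and policy visualization over all states reachable
         * from initialState, titled after the algorithm that produced the value function.
         */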
        public void simpleValueFunctionVis(ValueFunction valueFunction, Policy p,
                State initialState, Domain domain, HashableStateFactory hashingFactory, String title) {
            List<State> allStates = StateReachability.getReachableStates(initialState,
                    (SADomain) domain, hashingFactory);
            ValueFunctionVisualizerGUI gui = GridWorldDomain.getGridWorldValueFunctionVisualization(
                    allStates, valueFunction, p);
            gui.setTitle(title);
            gui.initGUI();
        }
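
        /**
         * For each iteration budget, trains a fresh Q-Learning agent for that many episodes in the
         * simulated environment, exploring with a Boltzmann (softmax) policy that selects actions with
         * probability proportional to exp(Q(s, a) / temperature). The policy returned by planFromState
         * is used for the printed policy map and GUI, while the recorded reward and step count come
         * from the final learning episode.
         */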
        public void runQLearning(BasicGridWorld gen, Domain domain,
                State initialState, RewardFunction rf, TerminalFunction tf,
                SimulatedEnvironment env, boolean showPolicyMap, double discount) {
            System.out.println("//Q Learning Analysis//");
            QLearning agent = null;
            Policy p = null;
            EpisodeAnalysis ea = null;
            int increment = MAX_ITERATIONS / NUM_INTERVALS;
            for (int numIterations = increment; numIterations <= MAX_ITERATIONS; numIterations += increment) {
                long startTime = System.nanoTime();
                // initial Q-values and learning rate both set to 0.99
                agent = new QLearning(
                        domain,
                        discount,
                        hashingFactory,
                        0.99, 0.99);
                // Boltzmann exploration with temperature 0.5
                agent.setLearningPolicy(new BoltzmannQPolicy(agent, 0.5));

                // run numIterations learning episodes, resetting the environment after each one
                for (int i = 0; i < numIterations; i++) {
                    ea = agent.runLearningEpisode(env);
                    env.resetEnvironment();
                }

                agent.initializeForPlanning(rf, tf, 1);
                p = agent.planFromState(initialState);
                AnalysisAggregator.addQLearningReward(calcRewardInEpisode(ea));
                // divide before casting so long runs don't overflow the int millisecond count
                AnalysisAggregator.addMillisecondsToFinishQLearning((int) ((System.nanoTime() - startTime) / 1000000));
                AnalysisAggregator.addStepsToFinishQLearning(ea.numTimeSteps());
            }
            AnalysisAggregator.printQLearningResults();
            MapPrinter.printPolicyMap(getAllStates(domain, rf, tf, initialState), p, gen.getMap());
            System.out.println("\n\n");

            // visualize the value function and policy
            if (showPolicyMap) {
                simpleValueFunctionVis((ValueFunction) agent, p, initialState, domain, hashingFactory, "Q-Learning");
            }
        }
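
        /**
         * Enumerates the state space by running a short value iteration pass; used only to supply
         * states to MapPrinter.printPolicyMap for the policy iteration and Q-learning results.
         */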
        private static List<State> getAllStates(Domain domain,
                RewardFunction rf, TerminalFunction tf, State initialState) {
            ValueIteration vi = new ValueIteration(
                    domain,
                    rf,
                    tf,
                    0.99,
                    new SimpleHashableStateFactory(),
                    0.5, 100);
            vi.planFromState(initialState);
            return vi.getAllStates();
        }
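
        /** Sums the (undiscounted) rewards collected over an episode. */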
        public double calcRewardInEpisode(EpisodeAnalysis ea) {
            double myRewards = 0;
            // sum all rewards
            for (int i = 0; i < ea.rewardSequence.size(); i++) {
                myRewards += ea.rewardSequence.get(i);
            }
            return myRewards;
        }
    }
}