Last active
June 13, 2017 16:38
-
-
Save mjs2600/b5249105b4fbee4bee1f to your computer and use it in GitHub Desktop.
Homework 5 Test Cases
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import burlap.behavior.singleagent.Policy; | |
import burlap.behavior.singleagent.planning.ActionTransitions; | |
import burlap.behavior.singleagent.planning.HashedTransitionProbability; | |
import burlap.behavior.singleagent.planning.PlannerDerivedPolicy; | |
import burlap.behavior.singleagent.planning.ValueFunctionPlanner; | |
import burlap.behavior.singleagent.planning.commonpolicies.GreedyDeterministicQPolicy; | |
import burlap.behavior.statehashing.DiscreteStateHashFactory; | |
import burlap.behavior.statehashing.StateHashTuple; | |
import burlap.domain.singleagent.graphdefined.GraphDefinedDomain; | |
import burlap.oomdp.core.Domain; | |
import burlap.oomdp.core.State; | |
import burlap.oomdp.core.TerminalFunction; | |
import burlap.oomdp.singleagent.RewardFunction; | |
import org.junit.Test; | |
import java.util.HashSet; | |
import java.util.LinkedList; | |
import java.util.List; | |
import java.util.Set; | |
import static org.junit.Assert.assertTrue; | |
public class LongPolicyIterationTest { | |
@Test | |
public void test1() { | |
int numStates = 100; | |
int numActionsPerState = 10; | |
double gamma = 0.75; | |
int count = getCount(numStates, numActionsPerState, gamma); | |
System.out.println("Iterations = " + count); | |
assertTrue(75 < count); | |
} | |
@Test | |
public void test2() { | |
int numStates = 400; | |
int numActionsPerState = 2; | |
double gamma = 0.95; | |
int count = getCount(numStates, numActionsPerState, gamma); | |
System.out.println("Iterations = " + count); | |
assertTrue(300 < count); | |
} | |
@Test | |
public void test3() { | |
int numStates = 20; | |
int numActionsPerState = 5; | |
double gamma = 0.5; | |
int count = getCount(numStates, numActionsPerState, gamma); | |
System.out.println("Iterations = " + count); | |
assertTrue(15 < count); | |
} | |
@Test | |
public void test4() { | |
int numStates = 12; | |
int numActionsPerState = 10; | |
double gamma = 0.25; | |
int count = getCount(numStates, numActionsPerState, gamma); | |
System.out.println("Iterations = " + count); | |
assertTrue(9 < count); | |
} | |
@Test | |
public void test5() { | |
int numStates = 1000; | |
int numActionsPerState = 3; | |
double gamma = 0.625; | |
int count = getCount(numStates, numActionsPerState, gamma); | |
System.out.println("Iterations = " + count); | |
assertTrue(750 < count); | |
} | |
@Test | |
public void test6() { | |
int numStates = 1000; | |
int numActionsPerState = 3; | |
double gamma = 0.382; | |
int count = getCount(numStates, numActionsPerState, gamma); | |
System.out.println("Iterations = " + count); | |
assertTrue(750 < count); | |
} | |
@Test | |
public void test7() { | |
for (int numStates = 1; numStates <= 4; numStates++) { | |
int numActionsPerState = 2; | |
for (double gamma = 0.5; gamma < 1; gamma += 0.01) { | |
int count = getCount(numStates, numActionsPerState, gamma); | |
assertTrue(count >= numStates); | |
} | |
} | |
} | |
private int getCount(int numStates, int numActionsPerState, double gamma) { | |
LongPolicyIteration lpi = new LongPolicyIteration(numStates, numActionsPerState, gamma); | |
LongPolicyIterationGraderTestRunVersion lpigtrv = new LongPolicyIterationGraderTestRunVersion(); | |
return lpigtrv.countPIIterations( | |
lpi.getGraphDefinedDomain().generateDomain(), | |
numStates, | |
lpi.getRF(), | |
lpi.getTF(), | |
gamma | |
); | |
} | |
public class LongPolicyIterationGraderTestRunVersion { | |
final double MAX_DELTA = 0.00001; | |
// domain should be generated from your GraphDefinedDomain object | |
protected int countPIIterations(Domain domain, int numStates, RewardFunction rf, TerminalFunction tf, double gamma) { | |
double maxDelta = MAX_DELTA; | |
MyPolicyIteration mpi = new MyPolicyIteration(domain, numStates, rf, tf, gamma, maxDelta); | |
State initState = GraphDefinedDomain.getState(domain, 0); | |
mpi.planFromState(initState); | |
return mpi.getNumIterations(); | |
} | |
class MyPolicyIteration extends ValueFunctionPlanner { | |
protected double maxPIDelta; | |
protected double maxEvalDelta; | |
protected PlannerDerivedPolicy evaluativePolicy; | |
protected boolean foundReachableStates = false; | |
protected int numIterations = 0; | |
int numStates; | |
DiscreteStateHashFactory hashFactory; | |
/**
 * Initializes the planner with a fresh {@code DiscreteStateHashFactory}, uses
 * the same threshold for both the policy-iteration and the policy-evaluation
 * delta, and starts with a {@code GreedyDeterministicQPolicy} over this planner.
 *
 * @param domain    domain passed to {@code VFPInit}
 * @param numStates number of states swept during policy evaluation
 * @param rf        reward function passed to {@code VFPInit}
 * @param tf        terminal function passed to {@code VFPInit}
 * @param gamma     discount factor passed to {@code VFPInit}
 * @param maxDelta  threshold stored as both {@code maxPIDelta} and {@code maxEvalDelta}
 */
public MyPolicyIteration(Domain domain, int numStates, RewardFunction rf, TerminalFunction tf,
        double gamma, double maxDelta) {
    DiscreteStateHashFactory factory = new DiscreteStateHashFactory();
    this.hashFactory = factory;
    VFPInit(domain, rf, tf, gamma, factory);

    this.numStates = numStates;
    this.maxEvalDelta = maxDelta;
    this.maxPIDelta = maxDelta;
    this.evaluativePolicy = new GreedyDeterministicQPolicy(this);
}
public void setPolicyClassToEvaluate(PlannerDerivedPolicy p){ | |
this.evaluativePolicy = p; | |
} | |
public Policy getComputedPolicy(){ | |
return (Policy)this.evaluativePolicy; | |
} | |
public void recomputeReachableStates(){ | |
this.foundReachableStates = false; | |
} | |
@Override | |
public void planFromState(State initialState) { | |
int iterations = 0; | |
this.initializeOptionsForExpectationComputations(); | |
if(this.performReachabilityFrom(initialState)){ | |
double delta; | |
do{ | |
StaticVFPlanner lastValueFunction = this.getCopyOfValueFunction(); | |
this.evaluativePolicy.setPlanner(lastValueFunction); | |
delta = this.evaluatePolicy(); | |
iterations++; | |
// DPrint.cl(this.debugCode, "Num iterations: " + iterations + "\nDelta: " + delta); | |
}while(delta > this.maxPIDelta); | |
} | |
this.numIterations = iterations; | |
} | |
@Override | |
public void resetPlannerResults(){ | |
super.resetPlannerResults(); | |
this.foundReachableStates = false; | |
this.numIterations = 0; | |
} | |
protected double evaluatePolicy(){ | |
if(!this.foundReachableStates){ | |
throw new RuntimeException("Cannot run VI until the reachable states have been found."+ | |
"Use planFromState method at least once or instead."); | |
} | |
double maxChangeInPolicyEvaluation = Double.NEGATIVE_INFINITY; | |
double delta; | |
do { | |
delta = 0; | |
for(int i = 0; i < this.numStates; i++){ | |
StateHashTuple sh = this.hashFactory.hashState(GraphDefinedDomain.getState(domain, i)); | |
double v = this.value(sh); | |
double maxQ = this.performFixedPolicyBellmanUpdateOn(sh,(Policy)this.evaluativePolicy); | |
delta = Math.max(Math.abs(maxQ - v), delta); | |
} | |
maxChangeInPolicyEvaluation = Math.max(delta, maxChangeInPolicyEvaluation); | |
} while(delta >= this.maxEvalDelta); | |
return maxChangeInPolicyEvaluation; | |
} | |
public boolean performReachabilityFrom(State si){ | |
StateHashTuple sih = this.stateHash(si); | |
//if this is not a new state and we are not required to perform a new reachability analysis, | |
// then this method does not need to do anything. | |
if(transitionDynamics.containsKey(sih) && this.foundReachableStates){ | |
return false; //no need for additional reachability testing | |
} | |
//add to the open list | |
LinkedList <StateHashTuple> openList = new LinkedList<StateHashTuple>(); | |
Set <StateHashTuple> openedSet = new HashSet<StateHashTuple>(); | |
openList.offer(sih); | |
openedSet.add(sih); | |
while(openList.size() > 0){ | |
StateHashTuple sh = openList.poll(); | |
//skip this if it's already been expanded | |
if(transitionDynamics.containsKey(sh)){ | |
continue; | |
} | |
mapToStateIndex.put(sh, sh); | |
//do not need to expand from terminal states | |
if(this.tf.isTerminal(sh.s)){ | |
continue; | |
} | |
//get the transition dynamics for each action and queue up new states | |
List <ActionTransitions> transitions = this.getActionsTransitions(sh); | |
for(ActionTransitions at : transitions){ | |
for(HashedTransitionProbability tp : at.transitions){ | |
StateHashTuple tsh = tp.sh; | |
if(!openedSet.contains(tsh) && !transitionDynamics.containsKey(tsh)){ | |
openedSet.add(tsh); | |
openList.offer(tsh); | |
} | |
} | |
} | |
} | |
this.foundReachableStates = true; | |
return true; | |
} | |
public int getNumIterations() { | |
return this.numIterations; | |
} | |
} | |
} | |
public static void main(String[] args) { | |
// test cases are accessed as follows: | |
// java LongPolicyIterationTest i | |
// where i is the test case number | |
int testCase = Integer.parseInt(args[0]); | |
System.out.println("Test case # " + testCase); | |
LongPolicyIterationTest lpit = new LongPolicyIterationTest(); | |
switch (testCase) { | |
case 1: | |
lpit.test1(); | |
break; | |
case 2: | |
lpit.test2(); | |
break; | |
case 3: | |
lpit.test3(); | |
break; | |
case 4: | |
lpit.test4(); | |
break; | |
case 5: | |
lpit.test5(); | |
break; | |
} | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment