Skip to content

Instantly share code, notes, and snippets.

@mjs2600
Last active June 13, 2017 16:38
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save mjs2600/b5249105b4fbee4bee1f to your computer and use it in GitHub Desktop.
Homework 5 Test Cases
import burlap.behavior.singleagent.Policy;
import burlap.behavior.singleagent.planning.ActionTransitions;
import burlap.behavior.singleagent.planning.HashedTransitionProbability;
import burlap.behavior.singleagent.planning.PlannerDerivedPolicy;
import burlap.behavior.singleagent.planning.ValueFunctionPlanner;
import burlap.behavior.singleagent.planning.commonpolicies.GreedyDeterministicQPolicy;
import burlap.behavior.statehashing.DiscreteStateHashFactory;
import burlap.behavior.statehashing.StateHashTuple;
import burlap.domain.singleagent.graphdefined.GraphDefinedDomain;
import burlap.oomdp.core.Domain;
import burlap.oomdp.core.State;
import burlap.oomdp.core.TerminalFunction;
import burlap.oomdp.singleagent.RewardFunction;
import org.junit.Test;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;
import static org.junit.Assert.assertTrue;
public class LongPolicyIterationTest {
@Test
public void test1() {
    // A 100-state, 10-action graph with gamma = 0.75 must take more than 75 PI iterations.
    final int states = 100;
    final int actionsPerState = 10;
    final double discount = 0.75;
    final int iterations = getCount(states, actionsPerState, discount);
    System.out.println("Iterations = " + iterations);
    assertTrue(75 < iterations);
}
@Test
public void test2() {
    // A 400-state, 2-action graph with gamma = 0.95 must take more than 300 PI iterations.
    final int states = 400;
    final int actionsPerState = 2;
    final double discount = 0.95;
    final int iterations = getCount(states, actionsPerState, discount);
    System.out.println("Iterations = " + iterations);
    assertTrue(300 < iterations);
}
@Test
public void test3() {
    // A 20-state, 5-action graph with gamma = 0.5 must take more than 15 PI iterations.
    final int states = 20;
    final int actionsPerState = 5;
    final double discount = 0.5;
    final int iterations = getCount(states, actionsPerState, discount);
    System.out.println("Iterations = " + iterations);
    assertTrue(15 < iterations);
}
@Test
public void test4() {
    // A 12-state, 10-action graph with gamma = 0.25 must take more than 9 PI iterations.
    final int states = 12;
    final int actionsPerState = 10;
    final double discount = 0.25;
    final int iterations = getCount(states, actionsPerState, discount);
    System.out.println("Iterations = " + iterations);
    assertTrue(9 < iterations);
}
@Test
public void test5() {
    // A 1000-state, 3-action graph with gamma = 0.625 must take more than 750 PI iterations.
    final int states = 1000;
    final int actionsPerState = 3;
    final double discount = 0.625;
    final int iterations = getCount(states, actionsPerState, discount);
    System.out.println("Iterations = " + iterations);
    assertTrue(750 < iterations);
}
@Test
public void test6() {
    // Same size as test5 but with a lower discount (gamma = 0.382); still expects > 750 iterations.
    final int states = 1000;
    final int actionsPerState = 3;
    final double discount = 0.382;
    final int iterations = getCount(states, actionsPerState, discount);
    System.out.println("Iterations = " + iterations);
    assertTrue(750 < iterations);
}
@Test
public void test7() {
    // Sanity sweep over tiny 2-action MDPs: for every gamma in [0.5, 1.0) stepping by
    // 0.01, policy iteration must take at least one iteration per state.
    for (int states = 1; states <= 4; states++) {
        for (double discount = 0.5; discount < 1; discount += 0.01) {
            final int iterations = getCount(states, 2, discount);
            assertTrue(iterations >= states);
        }
    }
}
/**
 * Builds the graph MDP for the given size/discount and returns how many
 * policy-iteration steps the grader's runner performs on it.
 */
private int getCount(int numStates, int numActionsPerState, double gamma) {
    LongPolicyIteration problem =
            new LongPolicyIteration(numStates, numActionsPerState, gamma);
    LongPolicyIterationGraderTestRunVersion runner =
            new LongPolicyIterationGraderTestRunVersion();
    Domain domain = problem.getGraphDefinedDomain().generateDomain();
    return runner.countPIIterations(domain, numStates, problem.getRF(), problem.getTF(), gamma);
}
public class LongPolicyIterationGraderTestRunVersion {
final double MAX_DELTA = 0.00001;
// domain should be generated from your GraphDefinedDomain object
/**
 * Plans from node 0 of the graph domain and reports how many policy-iteration
 * steps the planner took to converge.
 */
protected int countPIIterations(Domain domain, int numStates, RewardFunction rf, TerminalFunction tf, double gamma) {
    MyPolicyIteration planner =
            new MyPolicyIteration(domain, numStates, rf, tf, gamma, MAX_DELTA);
    planner.planFromState(GraphDefinedDomain.getState(domain, 0));
    return planner.getNumIterations();
}
class MyPolicyIteration extends ValueFunctionPlanner {
protected double maxPIDelta; // convergence threshold for the outer policy-iteration loop
protected double maxEvalDelta; // convergence threshold for the inner policy-evaluation sweeps
protected PlannerDerivedPolicy evaluativePolicy; // policy evaluated each iteration (greedy-deterministic by default)
protected boolean foundReachableStates = false; // true once performReachabilityFrom has run
protected int numIterations = 0; // PI iterations performed by the most recent planFromState call
int numStates; // graph node count; states are addressed by index 0..numStates-1
DiscreteStateHashFactory hashFactory; // state hashing for tabular value lookups
// Builds a tabular policy-iteration planner over a discrete graph domain.
// maxDelta is used for both the outer PI loop and the inner evaluation loop.
public MyPolicyIteration(Domain domain, int numStates, RewardFunction rf, TerminalFunction tf,
double gamma, double maxDelta) {
this.hashFactory = new DiscreteStateHashFactory();
// VFPInit must run before constructing the policy below, which references this planner.
this.VFPInit(domain, rf, tf, gamma, this.hashFactory);
this.maxPIDelta = maxDelta;
this.maxEvalDelta = maxDelta;
this.numStates = numStates;
// Default evaluation policy: greedy-deterministic over this planner's Q-values.
this.evaluativePolicy = new GreedyDeterministicQPolicy(this);
}
/** Replaces the default greedy-deterministic policy with a caller-supplied one. */
public void setPolicyClassToEvaluate(PlannerDerivedPolicy p) {
    this.evaluativePolicy = p;
}
/** The evaluation policy doubles as the final computed policy once planning ends. */
public Policy getComputedPolicy() {
    return (Policy) this.evaluativePolicy;
}
/** Forces performReachabilityFrom to redo its expansion on the next call. */
public void recomputeReachableStates() {
    this.foundReachableStates = false;
}
// Outer policy-iteration loop: repeatedly snapshot the value function, point the
// evaluation policy at that snapshot, and run policy evaluation until the change
// it reports drops to maxPIDelta or below. Stores the iteration count for
// getNumIterations(). Note: when reachability for initialState was already
// computed, performReachabilityFrom returns false and the count stays 0.
@Override
public void planFromState(State initialState) {
int iterations = 0;
this.initializeOptionsForExpectationComputations();
if(this.performReachabilityFrom(initialState)){
double delta;
do{
// Evaluate the policy derived from the previous iteration's value function.
StaticVFPlanner lastValueFunction = this.getCopyOfValueFunction();
this.evaluativePolicy.setPlanner(lastValueFunction);
delta = this.evaluatePolicy();
iterations++;
// DPrint.cl(this.debugCode, "Num iterations: " + iterations + "\nDelta: " + delta);
}while(delta > this.maxPIDelta);
}
this.numIterations = iterations;
}
/** Clears the superclass planner state plus this class's own bookkeeping. */
@Override
public void resetPlannerResults() {
    super.resetPlannerResults();
    this.numIterations = 0;
    this.foundReachableStates = false;
}
// Iterative policy evaluation: sweeps every graph state, applying the fixed-policy
// Bellman update, until the largest single-sweep value change falls below
// maxEvalDelta. Returns the largest change seen across all sweeps, which the
// outer loop in planFromState compares against maxPIDelta.
// Throws if called before reachability analysis has run.
protected double evaluatePolicy(){
if(!this.foundReachableStates){
throw new RuntimeException("Cannot run VI until the reachable states have been found."+
"Use planFromState method at least once or instead.");
}
double maxChangeInPolicyEvaluation = Double.NEGATIVE_INFINITY;
double delta;
do {
delta = 0;
// Sweep all states by graph index (0..numStates-1).
for(int i = 0; i < this.numStates; i++){
StateHashTuple sh = this.hashFactory.hashState(GraphDefinedDomain.getState(domain, i));
double v = this.value(sh);
// Bellman backup of this state's value under the current (fixed) policy.
double maxQ = this.performFixedPolicyBellmanUpdateOn(sh,(Policy)this.evaluativePolicy);
delta = Math.max(Math.abs(maxQ - v), delta);
}
maxChangeInPolicyEvaluation = Math.max(delta, maxChangeInPolicyEvaluation);
} while(delta >= this.maxEvalDelta);
return maxChangeInPolicyEvaluation;
}
// Breadth-first expansion of every state reachable from si, caching the transition
// dynamics of each expanded state. Returns false (doing nothing) when si was
// already expanded and reachability is current; returns true after a fresh pass.
public boolean performReachabilityFrom(State si){
StateHashTuple sih = this.stateHash(si);
//if this is not a new state and we are not required to perform a new reachability analysis,
// then this method does not need to do anything.
if(transitionDynamics.containsKey(sih) && this.foundReachableStates){
return false; //no need for additional reachability testing
}
//add to the open list
LinkedList <StateHashTuple> openList = new LinkedList<StateHashTuple>();
Set <StateHashTuple> openedSet = new HashSet<StateHashTuple>();
openList.offer(sih);
openedSet.add(sih);
while(openList.size() > 0){
StateHashTuple sh = openList.poll();
//skip this if it's already been expanded
if(transitionDynamics.containsKey(sh)){
continue;
}
mapToStateIndex.put(sh, sh);
//do not need to expand from terminal states
if(this.tf.isTerminal(sh.s)){
continue;
}
//get the transition dynamics for each action and queue up new states
List <ActionTransitions> transitions = this.getActionsTransitions(sh);
for(ActionTransitions at : transitions){
for(HashedTransitionProbability tp : at.transitions){
StateHashTuple tsh = tp.sh;
// Enqueue successors that are neither queued nor already expanded.
if(!openedSet.contains(tsh) && !transitionDynamics.containsKey(tsh)){
openedSet.add(tsh);
openList.offer(tsh);
}
}
}
}
this.foundReachableStates = true;
return true;
}
/** @return number of PI iterations performed by the most recent planFromState call */
public int getNumIterations() {
    return numIterations;
}
}
}
// test cases are accessed as follows:
// java LongPolicyIterationTest i
// where i is the test case number (1-7)
public static void main(String[] args) {
    // Fix: the original switch omitted cases 6 and 7 even though test6/test7 exist,
    // had no default branch, and crashed with ArrayIndexOutOfBoundsException when
    // run without arguments.
    if (args.length < 1) {
        System.err.println("Usage: java LongPolicyIterationTest <testCaseNumber 1-7>");
        return;
    }
    int testCase = Integer.parseInt(args[0]);
    System.out.println("Test case # " + testCase);
    LongPolicyIterationTest lpit = new LongPolicyIterationTest();
    switch (testCase) {
        case 1:
            lpit.test1();
            break;
        case 2:
            lpit.test2();
            break;
        case 3:
            lpit.test3();
            break;
        case 4:
            lpit.test4();
            break;
        case 5:
            lpit.test5();
            break;
        case 6:
            lpit.test6();
            break;
        case 7:
            lpit.test7();
            break;
        default:
            System.err.println("Unknown test case: " + testCase);
    }
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment