Skip to content

Instantly share code, notes, and snippets.

@nwertzberger
Created August 7, 2015 01:08
Show Gist options
  • Save nwertzberger/29c1278dfa0abf7bceec to your computer and use it in GitHub Desktop.
Save nwertzberger/29c1278dfa0abf7bceec to your computer and use it in GitHub Desktop.
/**
 * Derives a greedy policy from a table of expected state utilities and a transition model.
 *
 * <p>For every state, each candidate action is scored by summing
 * {@code P(s' | state, action) * U(s')} over every state {@code s'} in the utility table, and
 * the highest-scoring action is chosen. This is the straightforward (unoptimized) form: every
 * score walks the whole utility table, so one call to {@link #generateNewPolicy} performs
 * |S| x |A| x |S| probability lookups.
 */
public class UnoptimizedTransitionCalculator {
    // NOTE(review): not referenced anywhere in this class.
    private static final Logger logger = LoggerFactory.getLogger(
        UnoptimizedTransitionCalculator.class
    );

    /**
     * Builds a policy mapping every state in {@code expectedUtilities} to its best action.
     *
     * <p>States are processed with a parallel stream, so {@code transitions} and the
     * {@code Transition} objects it holds are read from multiple threads concurrently; they must
     * be safe for concurrent reads.
     *
     * @param actions candidate actions; must be non-empty whenever {@code expectedUtilities} is
     *     non-empty
     * @param expectedUtilities current utility estimate per state; its key set is both the set
     *     of states the policy covers and the set of successor states summed over
     * @param transitions transition model keyed by (state, action); must contain an entry for
     *     every state in {@code expectedUtilities} paired with every action in {@code actions}
     * @return a new {@code Policy} built from a concurrent map of state to chosen action
     * @throws java.util.NoSuchElementException if {@code actions} is empty and there is at least
     *     one state
     * @throws NullPointerException if a required (state, action) transition is missing
     */
    public Policy generateNewPolicy(
            Set<Action> actions,
            Map<State, Double> expectedUtilities,
            Map<StateAction, ? extends Transition> transitions) {
        return new Policy(
            expectedUtilities.keySet().parallelStream().collect(
                toConcurrentMap(
                    state -> state,
                    state -> maxAction(state, actions, transitions, expectedUtilities)
                )
            )
        );
    }

    /**
     * Returns the action with the highest expected utility from {@code state}.
     *
     * <p>Each action's utility is computed exactly once. {@code Comparator.comparing} would
     * recompute the key for both operands on every comparison, nearly doubling the number of
     * O(|S|) utility scans.
     *
     * @throws java.util.NoSuchElementException if {@code actions} is empty
     */
    private Action maxAction(
            final State state,
            final Set<Action> actions,
            final Map<StateAction, ? extends Transition> transitions,
            final Map<State, Double> expectedUtilities) {
        if (actions.isEmpty()) {
            throw new java.util.NoSuchElementException(
                "No actions available to choose from for state " + state);
        }
        Action best = null;
        double bestUtility = Double.NaN;
        boolean first = true;
        for (Action act : actions) {
            double utility = calculateExpectedUtility(state, transitions, expectedUtilities, act);
            // Double.compare with strict '>' matches Stream.max over a Double comparator:
            // NaN ranks highest, 0.0 ranks above -0.0, and ties keep the earlier action.
            if (first || Double.compare(utility, bestUtility) > 0) {
                best = act;
                bestUtility = utility;
                first = false;
            }
        }
        return best;
    }

    /**
     * Computes the expected utility of taking {@code act} in {@code state}.
     *
     * <p>The result is the sum over every state {@code s'} in {@code expectedUtilities} of
     * {@code transition.getStateProbability(s') * expectedUtilities.get(s')}.
     *
     * @throws NullPointerException if {@code transitions} has no entry for (state, act)
     */
    private double calculateExpectedUtility(
            State state,
            Map<StateAction, ? extends Transition> transitions,
            Map<State, Double> expectedUtilities,
            Action act) {
        Transition t = transitions.get(new StateAction(state, act));
        if (t == null) {
            throw new NullPointerException(
                "No transition defined for state " + state + " and action " + act);
        }
        return expectedUtilities
            .entrySet()
            .stream()
            .mapToDouble(entry -> t.getStateProbability(entry.getKey()) * entry.getValue())
            .sum();
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment