Created
October 15, 2018 06:30
-
-
Save Santara/f97fbffc881390667373514329a18c96 to your computer and use it in GitHub Desktop.
Quick-and-dirty implementation of Theorem 1 of Thomas et al., "High Confidence Off-Policy Evaluation".
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
def hcope_estimator(X, c, delta):
    """Return a high-confidence lower bound on the mean of ``X``.

    Implements the concentration bound of Theorem 1 in Thomas et al.,
    "High Confidence Off-Policy Evaluation".  Every sample is truncated at
    its threshold, ``Y_i = min(X_i, c_i)``, and the bound is::

        ( sum_i Y_i/c_i
          - 7 n ln(2/delta) / (3 (n - 1))
          - sqrt( 2 ln(2/delta) / (n - 1)
                  * (n sum_i (Y_i/c_i)^2 - (sum_i Y_i/c_i)^2) )
        ) / sum_i (1/c_i)

    Parameters
    ----------
    X : array_like of float, shape (n,)
        Importance weighted trajectory rewards from the behavior policy.
    c : float or array_like of float
        Truncation threshold(s).  A scalar applies the same threshold to
        every sample; otherwise it must broadcast to the shape of ``X``
        (list, tuple and ndarray are all accepted).
    delta : float
        ``1 - delta`` is the confidence of the estimator; must lie in
        ``(0, 1]``.

    Returns
    -------
    numpy.float64
        Lower bound for the mean as per Theorem 1.

    Raises
    ------
    ValueError
        If ``X`` is not one-dimensional, has fewer than two samples, if
        ``c`` cannot be broadcast to ``X`` or contains non-positive values,
        or if ``delta`` is outside ``(0, 1]``.
    """
    X = np.atleast_1d(np.asarray(X, dtype=float))
    if X.ndim != 1:
        raise ValueError("'X' must be a one-dimensional sequence of samples")
    n = X.shape[0]

    # The bound divides by (n - 1), so reject short histories before doing
    # any arithmetic.
    if n <= 1:
        raise ValueError("The value of 'n' must be greater than 1")
    if not 0.0 < delta <= 1.0:
        raise ValueError("'delta' must lie in the interval (0, 1]")

    # Scalars and sequences are handled uniformly; a shape mismatch makes
    # NumPy raise ValueError here.  (The result is a read-only view, which
    # is fine because it is never written to.)
    c = np.broadcast_to(np.asarray(c, dtype=float), X.shape)
    if np.any(c <= 0):
        raise ValueError("All thresholds in 'c' must be strictly positive")

    # Truncate each sample at its threshold and normalise by it.
    Y = np.minimum(X, c)
    ratios = Y / c
    inv_c_sum = np.sum(1.0 / c)
    log_term = np.log(2.0 / delta)

    # Empirical mean
    empirical_mean = np.sum(ratios) / inv_c_sum

    # Second term
    term2 = (7.0 * n * log_term) / (3.0 * (n - 1) * inv_c_sum)

    # Third term.  n*sum(r^2) - (sum r)^2 equals n*sum((r - mean r)^2); the
    # centred form can never go slightly negative through rounding, which
    # would otherwise turn the square root into NaN for near-constant data.
    spread = n * np.sum(np.square(ratios - np.mean(ratios)))
    term3 = np.sqrt((2.0 * log_term / (n - 1)) * spread) / inv_c_sum

    # Final estimate
    return empirical_mean - term2 - term3
if __name__ == "__main__":
    # Smoke test: four sample returns, one shared truncation threshold,
    # and an 80% confidence level (delta = 0.2).
    demo_args = {
        "X": np.asarray([1., 2., 3., 4.]),
        "c": 2.,
        "delta": 0.2,
    }
    print(hcope_estimator(**demo_args))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment