Created March 15, 2017 21:40
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from scipy.optimize import minimize, basinhopping, differential_evolution | |
import numpy as np | |
from copy import deepcopy | |
from skopt import Optimizer | |
from skopt.learning import GaussianProcessRegressor, RandomForestRegressor, ExtraTreesRegressor | |
from skopt.learning.gaussian_process.kernels import Matern | |
from skopt import space | |
from sklearn.svm import SVR | |
class MultiTaskOptProb():
    """Family of related toy SVR hyper-parameter tuning problems.

    ``space`` lists the three optimisation dimensions, which ``step``
    interprets as log10 of SVR's ``C``, ``gamma`` and ``epsilon``.
    ``context`` lists the per-task context dimensions: a categorical task
    index, optionally followed by two reals matching the features that
    ``reset`` returns when ``context_features`` is enabled.
    """

    def __init__(self, context_features=False, max_subtasks=6):
        self.context_features = context_features
        # One dimension per hyper-parameter; step() raises 10 to these values.
        self.space = [space.Real(-5.0, 5.0) for _ in range(3)]

        context = [space.Categorical(range(max_subtasks))]
        if context_features:
            # Wide enough to hold pw in [1, 2] and nz in [0.1, 1] from reset().
            context.extend([space.Real(0.0, 2.0), space.Real(0.0, 2.0)])
        self.context = context

        # Not read anywhere in the visible file.
        self.idx = 0

    def reset(self):
        """Draw a fresh synthetic regression task into ``self.data``.

        The inputs are 256 standard-normal 5-d rows. Targets are a projection
        onto non-negative random weights, raised to a sign-preserving power
        ``pw ~ U(1, 2)``, plus Gaussian noise of scale ``nz ~ U(0.1, 1)``.
        Each row goes to the training split with probability 0.6, otherwise
        to the validation split.

        Returns:
            ``[pw, nz]`` if ``context_features`` is set, otherwise ``[]``.
        """
        # NOTE: the order of the np.random calls below fixes the sampled task
        # for a given seed; keep it unchanged.
        inputs = np.random.randn(256, 5)
        weights = np.random.rand(5)
        targets = np.dot(inputs, weights)
        power = np.random.uniform(1.0, 2.0)
        noise = np.random.uniform(0.1, 1.0)
        targets = np.sign(targets) * (np.abs(targets) ** power)
        targets = targets + np.random.randn(len(targets)) * noise
        in_train = np.random.rand(len(inputs)) < 0.6
        self.data = (
            inputs[in_train],
            targets[in_train],
            inputs[~in_train],
            targets[~in_train],
        )
        return [power, noise] if self.context_features else []

    def step(self, P):
        """Score one SVR configuration on the current task.

        Args:
            P: log10 values for ``C``, ``gamma`` and ``epsilon``, in that
                order. ``zip`` ignores any extra entries.

        Returns:
            The negated ``SVR.score`` on the validation split, so lower is
            better. Requires a prior ``reset()`` so that ``self.data`` exists.
        """
        hyper = dict(
            (name, 10 ** value)
            for name, value in zip(['C', 'gamma', 'epsilon'], P)
        )
        model = SVR(**hyper)
        train_x, train_y, val_x, val_y = self.data
        model.fit(train_x, train_y)
        return -model.score(val_x, val_y)
# ---------------------------------------------------------------------------
# Experiment driver. For a sequence of related MultiTaskOptProb tasks, SVR
# hyper-parameters are tuned with a skopt Optimizer in one of two ways:
#   * "from scratch";
#   * "knowledge transfer": the best points found on earlier tasks are
#     evaluated first on each new task.
# The helpers live at module level rather than under the __main__ guard, so
# multiprocessing workers can find them under any start method.
# ---------------------------------------------------------------------------

FROM_SCRATCH = "From scratch"
KNOWLEDGE_TRANSFER = "Knowledge transfer"

# Surrogate-model factories, keyed by the --surrogate command-line value.
# Workers receive the key, not the factory: Pool.map pickles its callable and
# arguments, and lambdas cannot be pickled.
SURROGATES = {
    "GP": lambda: GaussianProcessRegressor(kernel=Matern()),
    "RFR": lambda: RandomForestRegressor(),
    "ETR": lambda: ExtraTreesRegressor(),
}


def repeat_opt_eval(seed, surrogate="GP", use_context_features=False,
                    max_similar_tasks=7, max_subproblem_steps=100):
    """Run both strategies on one seeded sequence of tasks.

    Args:
        seed: seed for NumPy's global RNG, which makes the repeat
            reproducible.
        surrogate: key into ``SURROGATES`` selecting the optimizer's model.
        use_context_features: forwarded to ``MultiTaskOptProb``.
        max_similar_tasks: number of consecutive tasks per strategy.
        max_subproblem_steps: objective evaluations per task.

    Returns:
        ``(SC, KT)``, for "from scratch" and "knowledge transfer"
        respectively. Each is a list with one best-so-far trace (a list of
        ``max_subproblem_steps`` floats) per task.
    """
    np.random.seed(seed)
    task = MultiTaskOptProb(
        context_features=use_context_features,
        max_subtasks=max_similar_tasks,
    )
    # Draw every task's data once, up front. Calling reset() inside each
    # strategy's loop would advance the shared RNG stream, so the second
    # strategy would be scored on different tasks than the first.
    episodes = []
    for _ in range(max_similar_tasks):
        task.reset()
        episodes.append(task.data)

    traces = {FROM_SCRATCH: [], KNOWLEDGE_TRANSFER: []}
    for mode in (FROM_SCRATCH, KNOWLEDGE_TRANSFER):
        transfer = (mode == KNOWLEDGE_TRANSFER)
        best_points = []
        for similar_task, data in enumerate(episodes):
            print("Episode " + str(similar_task) + ", " + str(seed))
            task.data = data
            trace = []
            best_y, best_x = np.inf, None
            # A fresh optimizer per task. Knowledge is carried over only
            # through best_points.
            solver = Optimizer(
                task.space,
                SURROGATES[surrogate](),
                acq_optimizer="sampling",
            )
            totry = deepcopy(best_points)
            for subp_idx in range(max_subproblem_steps):
                if transfer and totry:
                    P = totry.pop()
                    # Each transferred point stands in for one of the
                    # optimizer's random initial points. This relies on
                    # skopt's private attribute.
                    solver._n_random_starts = max(0, solver._n_random_starts - 1)
                else:
                    P = solver.ask()
                try:
                    v = task.step(P)
                except Exception as ex:
                    # A failed fit scores 1.0, i.e. an R^2 of -1.
                    v = 1.0
                    print(ex)
                if best_y > v:
                    best_y, best_x = v, P
                print("#" + str(similar_task) + "_" + str(subp_idx) + ", " + str(best_y))
                try:
                    solver.tell(P, v)
                except Exception as ex:
                    # Best effort: a rejected observation must not abort the run.
                    print(ex)
                trace.append(float(best_y))
            traces[mode].append(trace)
            # best_x stays None if no step ever improved on inf (e.g. NaN scores).
            if best_x is not None:
                best_points.append(deepcopy(best_x))
    return traces[FROM_SCRATCH], traces[KNOWLEDGE_TRANSFER]


def visualize_results(aresults):
    """Plot mean best-so-far curves on the last task for both strategies.

    ``aresults`` is ``{TF: {surrogate_name: {strategy: runs}}}``, where
    ``runs`` is a (repeats, tasks, steps) nested list as built in the
    ``__main__`` section. There is one subplot per (TF, surrogate) pair.
    """
    import matplotlib.pyplot as plt
    rows = len(aresults)
    cols = len(next(iter(aresults.values())))
    idx = 0
    for TF, by_surrogate in aresults.items():
        for B, results in by_surrogate.items():
            idx += 1
            plt.subplot(rows, cols, idx)
            # Iterate in a fixed order so each strategy keeps its colour
            # whatever the dict order.
            for k, c in zip((FROM_SCRATCH, KNOWLEDGE_TRANSFER), ('red', 'blue')):
                runs = np.array(results[k])
                # Mean over repeats of the trace on the LAST task (index -1),
                # i.e. after the most knowledge has been transferred.
                y = np.mean(runs[:, -1, :], axis=0)
                plt.scatter(range(len(y)), y, c=c, label=k)
            plt.xlabel('Iteration')
            plt.ylabel('Avg. objective')
            # TF is a bool for freshly computed results but the string
            # "true"/"false" after a JSON round trip; compare its text.
            with_features = str(TF).lower() == "true"
            plt.title(("Task features and number, " if with_features else "Task number, ") + str(B))
            plt.grid()
            plt.legend()
    plt.show()


if __name__ == "__main__":
    import argparse
    import json
    from functools import partial
    from multiprocessing import Pool

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--surrogate', nargs="?", default="GP", const="GP", type=str,
        choices=sorted(SURROGATES), help="Type of surrogate to use.")
    args = parser.parse_args()
    exp_typ = args.surrogate

    MAX_SUBPROBLEM_STEPS = 100
    MAX_SIMILAR_TASKS = 7
    REPEATS = 8
    recalculate = True

    results_file = 'results_' + exp_typ + '.json'

    if recalculate:
        results = {}
        pool = Pool()
        try:
            for use_context_features in [False]:
                solver_name = SURROGATES[exp_typ]().__class__.__name__
                worker = partial(
                    repeat_opt_eval,
                    surrogate=exp_typ,
                    use_context_features=use_context_features,
                    max_similar_tasks=MAX_SIMILAR_TASKS,
                    max_subproblem_steps=MAX_SUBPROBLEM_STEPS,
                )
                per_repeat = pool.map(worker, range(REPEATS))
                results[use_context_features] = {
                    solver_name: {
                        FROM_SCRATCH: [sc for sc, _ in per_repeat],
                        KNOWLEDGE_TRANSFER: [kt for _, kt in per_repeat],
                    },
                }
        finally:
            # Always release the worker processes, even if a repeat fails.
            pool.close()
            pool.join()
        with open(results_file, 'w') as f:
            json.dump(results, f)
    else:
        with open(results_file) as f:
            results = json.load(f)

    visualize_results(results)
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment