Last active: January 2, 2020 17:55
-
-
Save aadharna/b066a948841bbdf57f9483066341a5f2 to your computer and use it in GitHub Desktop.
devo gist
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Evolve the network weights of an NNagent playing a GVGAI Zelda level with
# the jDE routine from the `devo` package (presumably Brest et al.'s
# self-adaptive differential evolution — confirm against devo's docs), then
# load the best weights found back into the agent's network.
#
# NOTE: this was originally run as notebook cells; values that the notebook
# would have displayed from bare expressions are printed explicitly here.
import torch
from scipy import optimize
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
from collections import OrderedDict
import matplotlib.pyplot as plt
from optimization.Optimizer import PyTorchObjective
from utils.utils import zelda_spaces
from copy import deepcopy
import time
from utils.diff_evo import differential_evolution
import gym
import gvgai
from generator.levels.base import Generator
import ctypes as c
from agent.NNagent import NNagent
from agent.base import Agent
from generator.env_gen_wrapper import GridGame
from scipy.optimize import Bounds
import devo
import devo.jDE

# Agent whose network weights are being optimized.
_x = NNagent(
    GridGame(
        game='zelda',
        play_length=1000,
        path='./levels',
        lvl_name='1.txt',
        # NOTE(review): an earlier comment here read "monsters, key, door,
        # wall", but only two mechanic symbols are passed — confirm which
        # tiles '+' and 'g' denote in the level vocabulary.
        mechanics=['+', 'g'],
        images=False,
    )
)

# Exposes the agent's parameters as a flat vector (z.x0) together with the
# fitness callbacks that the C optimizer calls.
z = PyTorchObjective(_x)

# Optimizer settings. What each value means is inferred from its position in
# the devo.jDE.run call below — confirm against that function's signature.
JDE_BUDGET = 10000  # num_gen = JDE_BUDGET / popsize
JDE_F = 0.5  # presumably the initial mutation factor F
JDE_CR = 0.9  # presumably the initial crossover rate CR
LOWER_BOUND, UPPER_BOUND = -5.0, 5.0  # presumably per-weight search bounds

start = time.time()

# Read these attributes before create_population(). This keeps the same
# evaluation order as the original single-expression call.
# Tip: try increasing the popsize by a lot.
popsize = z.popsize
objective_fn = z.fun_c
n_dims = z.x0.shape[0]

# The C routine reads these buffers through double* pointers, so they must be
# C-contiguous float64. ascontiguousarray returns the same array (no copy)
# when it already meets that requirement. Binding the arrays to names also
# keeps them alive for the whole call.
init_population = np.ascontiguousarray(z.create_population(), dtype=np.float64)
init_fitnesses = np.ascontiguousarray(z.init_fitnesses, dtype=np.float64)

result_02 = devo.jDE.run(
    JDE_BUDGET,
    popsize,
    JDE_F,
    JDE_CR,
    objective_fn,
    n_dims,
    LOWER_BOUND,
    UPPER_BOUND,
    init_population.ctypes.data_as(c.POINTER(c.c_double)),
    init_fitnesses.ctypes.data_as(c.POINTER(c.c_double)),
    z.results_callback,
)

# Elapsed wall-clock seconds. The name `end` is kept for compatibility.
end = time.time() - start
print(f"jDE run took {end / 3600:.3f} h")

# Re-evaluate every member of the final population.
result = [z.fun(x, len(x)).value for x in z.out_population]
print("re-evaluated fitnesses:", np.array(result))

# Flatten once and use the same 1-D array for both argmin and indexing. That
# keeps the index valid whatever singleton axes out_fitnesses carries.
fitnesses = np.ravel(z.out_fitnesses)
print("final fitnesses:", fitnesses)
print("min fitness:", fitnesses.min())

min_result = np.argmin(fitnesses)
best_score = fitnesses[min_result]
best_weights = z.out_population[min_result]
print("best score:", best_score, "| number of weights:", len(best_weights))
print("best re-evaluated:", z.fun(best_weights, len(best_weights)).value)

# Install the best weights in the network. z.f is presumably the torch
# module wrapped by PyTorchObjective — confirm.
z.x0 = best_weights
state_dict = z.unpack_parameters(z.x0)
z.f.load_state_dict(state_dict)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.