Skip to content

Instantly share code, notes, and snippets.

@aadharna
Last active January 2, 2020 17:55
Show Gist options
  • Save aadharna/b066a948841bbdf57f9483066341a5f2 to your computer and use it in GitHub Desktop.
Save aadharna/b066a948841bbdf57f9483066341a5f2 to your computer and use it in GitHub Desktop.
devo gist
import torch
from scipy import optimize
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
from collections import OrderedDict
import matplotlib.pyplot as plt
# from IPython import display
from optimization.Optimizer import PyTorchObjective
#from baselines.common.vec_env.shmem_vec_env import ShmemVecEnv
#from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv
from utils.utils import zelda_spaces
from copy import deepcopy
import time
from utils.diff_evo import differential_evolution
import gym
import gvgai
from generator.levels.base import Generator
import numpy as np
import ctypes as c
from agent.NNagent import NNagent
from agent.base import Agent
from generator.env_gen_wrapper import GridGame
from scipy.optimize import Bounds
# Build the NN agent on a GridGame for game='zelda' and wrap it in
# PyTorchObjective. The objective exposes the parameter vector (z.x0) that the
# jDE run optimises, and can unpack such a vector back into a state dict for
# the network (z.unpack_parameters / z.f.load_state_dict at the end of the script).
# NOTE(review): the original comment on `mechanics` read "monsters, key, door,
# wall", but only two symbols ('+', 'g') are passed. Their meaning is defined by
# GridGame and the level file, so confirm before relying on it.
_x = NNagent(
    GridGame(
        game='zelda',
        play_length=1000,
        path='./levels',
        lvl_name='1.txt',
        mechanics=['+', 'g'],
        images=False,
    )
)
z = PyTorchObjective(_x)
import devo
import devo.jDE
def _as_double_ptr(arr, what):
    """Return a ``double *`` to *arr*'s buffer for the C optimiser.

    The pointer is reinterpreted on the C side as contiguous doubles, so any
    other dtype or memory layout would be silently misread. Raise instead.

    Raises:
        TypeError: if *arr* is not a C-contiguous float64 array.
    """
    if arr.dtype != np.float64 or not arr.flags['C_CONTIGUOUS']:
        raise TypeError(
            f"{what} must be a C-contiguous float64 array, got dtype={arr.dtype}, "
            f"C_CONTIGUOUS={arr.flags['C_CONTIGUOUS']}"
        )
    return arr.ctypes.data_as(c.POINTER(c.c_double))


# Positional arguments for devo.jDE.run. The meanings are inferred from the
# values and the note below; TODO confirm against devo's documentation.
JDE_BUDGET = 10000   # presumably the number of function evaluations (num_fn)
JDE_F = 0.5          # presumably the (initial) DE scale factor F
JDE_CR = 0.9         # presumably the (initial) crossover rate CR
LOWER_BOUND = -5.0   # presumably the lower bound on each parameter
UPPER_BOUND = 5.0    # presumably the upper bound on each parameter

# num_gen = num_fn / popsize
# Try increasing the popsize argument by a lot.
# perf_counter is monotonic, so the measured duration is immune to wall-clock
# adjustments during a long run.
start = time.perf_counter()
# Keep the initial population in a named variable so the buffer handed to C
# stays explicitly referenced for the whole run. It is created before
# z.init_fitnesses is read, matching the original argument evaluation order.
initial_population = z.create_population()
result_02 = devo.jDE.run(
    JDE_BUDGET,
    z.popsize,
    JDE_F,
    JDE_CR,
    z.fun_c,
    z.x0.shape[0],
    LOWER_BOUND,
    UPPER_BOUND,
    _as_double_ptr(initial_population, "initial population"),
    _as_double_ptr(z.init_fitnesses, "z.init_fitnesses"),
    z.results_callback,
)
# `end` holds the elapsed seconds, not an end timestamp. The name is kept for
# compatibility with the original script.
end = time.perf_counter() - start
print(f"jDE finished in {end / 3600:.3f} h ({end:.1f} s)")
# Re-score every member of the final population with the Python objective, as
# a sanity check against the fitnesses reported by the optimiser.
# NOTE(review): z.out_population / z.out_fitnesses are presumably filled in by
# z.results_callback during the jDE run; confirm.
result = [z.fun(x, len(x)).value for x in z.out_population]
fitnesses = z.out_fitnesses.squeeze()  # computed once, reused below
print("re-scored population fitnesses:", np.array(result))
print("optimiser-reported fitnesses:", fitnesses)
print("best reported fitness:", fitnesses.min())

# The script treats lower fitness as better: the best individual is the argmin.
min_result = np.argmin(fitnesses)
best_score = z.out_fitnesses[min_result]
best_weights = z.out_population[min_result]
# Re-evaluate the winner and keep the result, which was previously discarded.
best_rescored = z.fun(best_weights, len(best_weights)).value
print(
    f"best individual #{min_result}: reported {best_score}, "
    f"re-scored {best_rescored}, weights shape {np.shape(best_weights)}"
)

# Install the best parameter vector in the objective and load it into the network.
z.x0 = best_weights
state_dict = z.unpack_parameters(z.x0)
z.f.load_state_dict(state_dict)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment