Created
June 12, 2016 22:48
-
-
Save mmajewsk/fb0cff5be99d6384bef229036a5e03b2 to your computer and use it in GitHub Desktop.
PYthon's Poor Users' Parameter Toolbox
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
## created by Hawker (https://github.com/hawkerpl)
## inspired by https://github.com/SmokinCaterpillar/pypet
##
import itertools | |
def cartesian_generator(somedict):
    """Lazily enumerate every combination of the candidate values in *somedict*.

    *somedict* maps each parameter name to an iterable of candidate values.
    One dict is yielded per element of the Cartesian product of those
    iterables (``itertools.product`` order: the last key varies fastest).
    Each yielded dict maps every name to the value picked for it.
    An empty mapping yields exactly one empty dict.
    """
    param_names = somedict.keys()
    for picked in itertools.product(*somedict.values()):
        yield {name: value for name, value in zip(param_names, picked)}
import os | |
import pandas as pd | |
class EnvBase(object):
    """On-disk storage for one trajectory (one parameter sweep).

    Files live under ``os.path.join(dump_dir, trajectory)``:
    ``<trajectory>_params.py`` holds the parameters and
    ``<trajectory>_result.csv`` holds one row per iteration.
    """

    def __init__(self, trajectory="", dump_dir=""):
        self.name = trajectory
        self.dump_dir = dump_dir
        self.base_dir = os.path.join(self.dump_dir, self.name)
        # With both arguments empty, base_dir is "" (the current directory).
        # os.makedirs("") would raise, so only create a real, missing path.
        if self.base_dir and not os.path.exists(self.base_dir):
            os.makedirs(self.base_dir)
        self.params_file = os.path.join(self.base_dir, trajectory + "_params.py")
        self.data_file = os.path.join(self.base_dir, trajectory + "_result.csv")

    @property
    def data(self):
        """Load the results CSV into a pandas DataFrame."""
        return pd.read_csv(self.data_file)

    def write_data_header(self):
        """Create (or truncate) the results CSV and write its header row.

        No trailing newline is written; every result row is prefixed with one.
        NOTE(review): the header names a single ``result`` column, so outputs
        with several values produce rows wider than the header.
        """
        with open(self.data_file, 'w') as f:
            f.write('iteration,result')

    def write_params(self, params):
        """Write *params* to the params file as ``params=<str(params)>``."""
        with open(self.params_file, 'w') as f:
            f.write("params=" + str(params))

    def create_header(self, params):
        """Set ``self.header`` to the parameter names (the keys of *params*)."""
        # pd.DataFrame(params).columns would raise for value lists of
        # different lengths or all-scalar values; only the names are needed.
        self.header = pd.Index(list(dict(params)))

    @staticmethod
    def _as_row_values(output):
        """Return *output* as the list of cells that follow the iteration index."""
        # Strings and dicts are recorded as single values, not split into cells.
        if isinstance(output, (str, bytes, dict)):
            return [output]
        try:
            return list(output)
        except TypeError:
            # Non-iterable scalars (int, float, None, ...) form a single cell.
            return [output]

    def add_result(self, i, output):
        """Append the row ``i,<output values...>`` to the results CSV.

        *output* may be a list, tuple or other iterable of values, or a
        single value. Cells containing commas, quotes or newlines are quoted,
        so the file stays parseable by ``data``.
        """
        import csv
        import io

        row_data = [i] + self._as_row_values(output)
        buffer = io.StringIO()
        csv.writer(buffer).writerow([str(element) for element in row_data])
        # Drop csv's own line terminator: rows are *prefixed* with "\n" to
        # match the header, which is written without a trailing newline.
        result_line = buffer.getvalue().rstrip('\r\n')
        with open(self.data_file, "a") as f:
            f.write('\n' + result_line)
class Environment(object):
    """Drive a parameter sweep: generate rows, evaluate a function, log results.

    Results are persisted through an ``EnvBase`` instance, so an interrupted
    sweep can be resumed with ``rerun``.
    """

    def __init__(self, trajectory_name="", dump_dir=""):
        self.env_base = EnvBase(trajectory_name, dump_dir)
        # Cached index of the last iteration written to the results file.
        # None means "unknown; read it from disk on demand".
        self.last_iteration = None

    def add_params(self, params):
        """Store *params* and persist them via ``self.env_base.write_params``."""
        self.params = params
        self.env_base.write_params(params)

    def add_generator(self, generator):
        """Register *generator*, a callable that turns ``self.params`` into rows.

        Must be called after ``add_params`` (it reads ``self.params``).
        """
        # Remember how to rebuild the row iterator. A generator consumed by an
        # earlier (possibly crashed) pass cannot be restarted, and reusing it
        # would pair row indices with the wrong rows in rerun().
        self._generator_factory = generator
        self._generator_params = self.params
        self._generator_unused = True
        self.generator = generator(self.params)

    def _fresh_rows(self):
        """Return an iterable over all rows, starting from the first one."""
        if getattr(self, "_generator_unused", False):
            # First pass: use the iterator built in add_generator().
            self._generator_unused = False
            return self.generator
        factory = getattr(self, "_generator_factory", None)
        if factory is None:
            # self.generator was assigned directly; nothing to rebuild it from.
            return self.generator
        self.generator = factory(self._generator_params)
        return self.generator

    def run(self, function):
        """Evaluate *function* on every row, starting a fresh results file.

        *function* receives one row. Its return value is passed unchanged to
        ``self.env_base.add_result`` together with the row index.
        """
        self.env_base.write_data_header()
        # Writing the header truncates the file, so nothing is recorded yet.
        self.last_iteration = None
        for i, row in enumerate(self._fresh_rows()):
            self.env_base.add_result(i, function(row))
            self.last_iteration = i

    @property
    def last_iteration_number(self):
        """Index of the last iteration recorded in the results file.

        Returns -1 when the file holds no result rows. The value is cached in
        ``self.last_iteration`` and kept current by run() and rerun().
        """
        if self.last_iteration is not None:
            return self.last_iteration
        df = self.env_base.data
        if df.empty:
            self.last_iteration = -1
        else:
            self.last_iteration = int(df['iteration'].iloc[-1])
        return self.last_iteration

    def rerun(self, function):
        """Resume a sweep, evaluating only rows after the last recorded one."""
        last_done = self.last_iteration_number
        for i, row in enumerate(self._fresh_rows()):
            if i <= last_done:
                continue  # this iteration's result is already in the file
            self.env_base.add_result(i, function(row))
            self.last_iteration = i
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment