Create a gist now

Instantly share code, notes, and snippets.

ソートしてくれ頼む
# coding: utf-8
import random
import difflib
import json
import csv
import sort_runner as sr
GAMMA = 0.8  # discount factor applied to the best next-state Q value
EPSILON_INIT = 1  # exploration probability at the start of every episode


class QL:
    """Tabular Q-learning agent that learns swap sequences which sort a list.

    ``self.q`` maps a state string (the environment's ``current_exp``, e.g.
    ``'3,1,2'``) to a dict ``{action: q_value}``.  Only actions that were
    actually taken from a state are stored; ``_max_q`` treats missing ones
    as having value 0.
    """

    @classmethod
    def with_learned_model(cls, csv_path):
        """Build an agent whose Q-table is loaded from a file written by learn().

        Each tab-separated row is ``state<TAB>action_repr<TAB>q_value``.
        Blank rows are ignored.
        """
        q = {}
        with open(csv_path, 'r') as f:
            reader = csv.reader(f, delimiter='\t')
            for row in reader:
                if not row:
                    continue
                # SECURITY NOTE(review): eval executes arbitrary text taken
                # from the model file -- only load model files you trust.
                # The action column is an action repr such as
                # 'SwapPosAAndPosB(3,1)', evaluated inside the sr module.
                action = eval('sr.%s' % row[1])
                q.setdefault(row[0], {})[action] = float(row[2])
        ql = cls()
        ql.q = q
        return ql

    def __init__(self):
        self.q = {}

    def learn(self, episodes=100, model_path='model.csv'):
        """Train for ``episodes`` episodes, then dump the Q-table to ``model_path``.

        Rows are written as ``state<TAB>action<TAB>q_value``.  The value is
        written with ``%r`` (full precision) rather than ``%f``, which would
        round small Q-values (GAMMA ** k for long paths) down to 0.000000
        and lose them when the model is reloaded.
        """
        for i in range(episodes):
            print('episode %d' % i)
            self.do_episode()
        with open(model_path, 'w') as f:
            for s, av in self.q.items():
                for a, v in av.items():
                    f.write('%s\t%s\t%r\n' % (s, a, v))

    def do_episode(self):
        """Run one episode on a random permutation until it is sorted.

        Applies the update ``Q(s, a) = r + GAMMA * max_a' Q(s', a')``
        (i.e. a learning rate of 1) after every step.
        """
        runner = sr.Runner()
        # NOTE(review): epsilon restarts at EPSILON_INIT every episode and only
        # drops by 1e-6 per step, so episodes are almost purely random
        # exploration -- confirm this is intended.
        epsilon = EPSILON_INIT
        actions = sr.actions
        while not runner.finished():
            current = runner.env.current_exp
            a = self._choice_action(current, actions, epsilon)
            runner.step(a)
            next_state = runner.env.current_exp
            q_val = self._reward(runner) + (GAMMA * self._max_q(next_state, actions))
            self._set_q_val(current, a, q_val)
            if epsilon > 0:
                epsilon -= 0.000001

    def _choice_action(self, current, actions, epsilon, greedy=False, verbose=False):
        """Pick an action for state ``current`` (epsilon-greedy).

        An unseen state gets a uniformly random action.  Otherwise, when
        ``greedy`` is set or ``random.random() > epsilon``, a random choice
        among the stored actions with the highest Q value is returned; else
        a uniformly random action from ``actions``.
        """
        if current not in self.q:
            return random.choice(actions)
        if greedy or random.random() > epsilon:
            max_q = max(self.q[current].values())
            candidates = [a for a, q in self.q[current].items() if q == max_q]
            if verbose:
                print('max_q: %f' % max_q)
            return random.choice(candidates)
        return random.choice(actions)

    def _set_q_val(self, current, action, q_val):
        """Store ``q_val`` as Q(current, action), creating the state entry if needed."""
        self.q.setdefault(current, {})[action] = q_val

    def _reward(self, runner):
        """Sparse reward: 1 once the runner's list is sorted, otherwise 0."""
        if runner.finished():
            return 1
        return 0

    # Alternative shaped reward (disabled): 100 on success, otherwise +1 for
    # every position the last step changed to its correct value.
    # def _reward(self, runner):
    #     r = 0
    #     if runner.finished():
    #         return 100
    #     else:
    #         for i in range(len(runner.env.collect_answer)):
    #             ci = runner.env.current[i]
    #             if ci != runner.env.old_state[i] and ci == runner.env.collect_answer[i]:
    #                 r += 1
    #         return r

    def _max_q(self, state, actions):
        """Return max Q(state, a) over ``actions``; unknown entries count as 0."""
        if state not in self.q:
            return 0
        return max(self.q[state].get(a, 0) for a in actions)

    def run(self, input):
        """Greedily sort ``input`` with the learned Q-table, printing each step.

        There is no step limit: the loop ends only once the list is sorted.
        """
        runner = sr.Runner(input)
        actions = sr.actions
        while not runner.finished():
            a = self._choice_action(runner.env.current_exp, actions, 0, greedy=True, verbose=True)
            runner.step(a, verbose=True)
            print('-----------')
if __name__ == '__main__':
    # Train from scratch and write the Q-table to model.csv.  Guarded so that
    # importing this module does not start a 100-episode training run.
    ql = QL()
    ql.learn()
    # To replay a saved model on a given permutation instead, use:
    #     import sys
    #     ql = QL.with_learned_model(sys.argv[1])
    #     ql.run([int(s) for s in sys.argv[2].split(',')])
# coding: utf-8
import random
import copy
TARGET_ELEMENT_SIZE = 10  # length of the permutation the agent learns to sort


class Runner:
    """Drive one sorting episode: owns a SortEnvironment and counts steps."""

    def __init__(self, input=None):
        # The target is the sorted list [0, 1, ..., TARGET_ELEMENT_SIZE - 1].
        # Pass a real list, not a range: on Python 3 a list never compares
        # equal to a range object, so the episode could never finish.
        self.env = SortEnvironment(list(range(TARGET_ELEMENT_SIZE)), input)
        self.step_num = 0

    def step(self, action, verbose=False):
        """Apply ``action`` to the environment, optionally logging the change."""
        action.do(self.env)
        if verbose:
            print('step %d, %s' % (self.step_num, action))
            print('%s -> %s' % (self.env.old_state, self.env.current))
        self.step_num += 1

    def finished(self):
        """Return True once the environment reports the list is sorted."""
        return self.env.collected()
class SortEnvironment:
    """A list that is sorted in place by swap actions.

    Attributes (plain instance attributes):
        current: the list in its present order.
        old_state: a copy of ``current`` taken just before the latest swap
            (the starting order if no swap has happened yet).
        collect_answer: the target order, stored as a list.
    """

    def __init__(self, collect_answer, input=None):
        """Start from ``input`` or, when it is empty/None, a random shuffle.

        ``collect_answer`` may be any iterable (e.g. a ``range``); it is kept
        as a list so that ``collected`` can compare it with ``current`` (a
        list never equals a ``range`` on Python 3).  ``input`` is copied, so
        later swaps do not mutate the caller's list.
        """
        self.collect_answer = list(collect_answer)
        if not input:
            self.current = random.sample(self.collect_answer, len(self.collect_answer))
        else:
            self.current = list(input)
        self.old_state = copy.copy(self.current)

    @property
    def current_exp(self):
        """State key: the elements of ``current`` joined by commas, e.g. '2,0,1'."""
        return ','.join([str(s) for s in self.current])

    def collected(self):
        """Return True when ``current`` equals the target order."""
        return self.current == self.collect_answer

    def biggerThan(self, pos_a, pos_b):
        """Return True if the element at ``pos_a`` is greater than the one at ``pos_b``."""
        return self.current[pos_a] > self.current[pos_b]

    def swap(self, pos_a, pos_b):
        """Exchange two positions of ``current``, remembering the prior order."""
        self.old_state = copy.copy(self.current)
        cur = self.current
        cur[pos_a], cur[pos_b] = cur[pos_b], cur[pos_a]
class SwapPosAAndPosB:
    """Action that swaps the elements at positions ``pos_a`` and ``pos_b``.

    Instances compare and hash by their positions.  An action rebuilt from
    its repr when a saved model is loaded is therefore the same dict key as
    the matching action created by this module.
    """

    def __init__(self, pos_a, pos_b):
        self.pos_a = pos_a
        self.pos_b = pos_b

    def do(self, env):
        """Swap the two positions in ``env`` (any object with a ``swap`` method)."""
        env.swap(self.pos_a, self.pos_b)

    def _key(self):
        # Identity of the action for equality and hashing.
        return (self.pos_a, self.pos_b)

    def __eq__(self, other):
        if not isinstance(other, SwapPosAAndPosB):
            return NotImplemented
        return self._key() == other._key()

    def __ne__(self, other):
        # Python 2 does not derive != from ==, so spell it out.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __hash__(self):
        return hash(self._key())

    def __repr__(self):
        # This exact text is written to model files and evaluated on load.
        return 'SwapPosAAndPosB(%d,%d)' % (self.pos_a, self.pos_b)
# Every pair of distinct positions exactly once, as (higher, lower) index
# pairs, ordered by the higher index and then by the lower one.
actions = [
    SwapPosAAndPosB(hi, lo)
    for hi in range(TARGET_ELEMENT_SIZE)
    for lo in range(hi)
]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment