Last active
November 5, 2018 13:56
-
-
Save yoheitaonishi/4f530100fc822137194c480e76325734 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import random | |
import itertools | |
import scipy.misc | |
import matplotlib.pyplot as plt | |
class gameOb:
    """One square object on the gridworld board (the hero, a goal or a fire).

    Attributes:
        x, y: column and row of the object's top-left cell.
        size: side length, in cells, of the square it covers when drawn.
        intensity: value written into its colour channel when drawn.
        channel: index of the RGB channel the object is drawn into.
        reward: reward for touching it (``None`` is used for the hero).
        name: label such as 'hero', 'goal' or 'fire'.
    """

    def __init__(self, coordinates, size, intensity, channel, reward, name):
        # Only the first two entries of ``coordinates`` are read, as (x, y).
        self.x, self.y = coordinates[0], coordinates[1]
        self.size, self.intensity, self.channel = size, intensity, channel
        self.reward, self.name = reward, name
class gameEnv():
    """Gridworld environment: a hero collects goals (+1) and avoids fires (-1).

    The board is ``size`` x ``size`` cells. Observations are 84x84x3 uint8
    images produced by ``renderEnv``; with ``partial`` true only the 3x3
    neighbourhood centred on the hero is rendered. Four actions are
    available (see ``moveChar``).
    """

    def __init__(self,partial,size):
        self.sizeX = size
        self.sizeY = size
        self.actions = 4
        self.objects = []
        self.partial = partial
        a = self.reset()
        plt.imshow(a,interpolation="nearest")

    def reset(self):
        """Respawn the hero, four goals and two fires; return the first frame."""
        self.objects = []
        # (channel, reward, name) in the original spawn order. The hero must be
        # objects[0] because moveChar reads it from there.
        spawn_table = [
            (2, None, 'hero'),
            (1, 1, 'goal'),
            (0, -1, 'fire'),
            (1, 1, 'goal'),
            (0, -1, 'fire'),
            (1, 1, 'goal'),
            (1, 1, 'goal'),
        ]
        for channel, reward, name in spawn_table:
            # newPosition is called after each append so objects never overlap.
            self.objects.append(gameOb(self.newPosition(), 1, 1, channel, reward, name))
        state = self.renderEnv()
        self.state = state
        return state

    def moveChar(self,direction):
        """Move the hero one cell: 0 - up, 1 - down, 2 - left, 3 - right.

        Moves that would leave the board are ignored. Returns the move
        penalty, which is always 0.0 (bumping into a wall is not penalised).
        """
        hero = self.objects[0]
        if direction == 0 and hero.y >= 1:
            hero.y -= 1
        if direction == 1 and hero.y <= self.sizeY - 2:
            hero.y += 1
        if direction == 2 and hero.x >= 1:
            hero.x -= 1
        if direction == 3 and hero.x <= self.sizeX - 2:
            hero.x += 1
        return 0.0

    def newPosition(self):
        """Return a random (x, y) cell that no current object occupies."""
        occupied = {(obj.x, obj.y) for obj in self.objects}
        points = [p for p in itertools.product(range(self.sizeX), range(self.sizeY))
                  if p not in occupied]
        # Kept as choice(..., replace=False) so the RNG stream is unchanged.
        location = np.random.choice(range(len(points)),replace=False)
        return points[location]

    def checkGoal(self):
        """Resolve a collision between the hero and another object.

        If the hero shares a cell with another object, that object is
        removed, a new one of the same kind ('goal' if its reward was 1,
        otherwise 'fire') is spawned on a free cell, and ``(reward, False)``
        is returned. Otherwise ``(0.0, False)`` is returned. The done flag is
        always False.
        """
        others = []
        for obj in self.objects:
            if obj.name == 'hero':
                hero = obj
            else:
                others.append(obj)
        for other in others:
            if hero.x == other.x and hero.y == other.y:
                self.objects.remove(other)
                if other.reward == 1:
                    self.objects.append(gameOb(self.newPosition(),1,1,1,1,'goal'))
                else:
                    self.objects.append(gameOb(self.newPosition(),1,1,0,-1,'fire'))
                return other.reward,False
        return 0.0,False

    @staticmethod
    def _resizeChannel(channel, out_size=84):
        """Byte-scale a 2-D array to uint8 and upsample it to out_size x out_size.

        Replaces ``scipy.misc.imresize(channel, ..., interp='nearest')``, which
        SciPy 1.3 removed. Like imresize on float input, values are min/max
        scaled to 0..255, and a constant array maps to 0. The result is then
        nearest-neighbour resampled by pixel centre. This matches imresize
        exactly when out_size is a multiple of the input size (e.g. 7->84 or
        3->84). For other ratios, pixel boundaries may differ by one.
        """
        data = np.asarray(channel, dtype=np.float64)
        cmin = data.min()
        cscale = data.max() - cmin
        if cscale == 0:
            cscale = 1
        scaled = (((data - cmin) * (255.0 / cscale)).clip(0, 255) + 0.5).astype(np.uint8)
        rows = np.minimum(((np.arange(out_size) + 0.5) * data.shape[0] / out_size).astype(int),
                          data.shape[0] - 1)
        cols = np.minimum(((np.arange(out_size) + 0.5) * data.shape[1] / out_size).astype(int),
                          data.shape[1] - 1)
        return scaled[np.ix_(rows, cols)]

    def renderEnv(self):
        """Render the board as an 84x84x3 uint8 image.

        Objects are drawn on a (sizeY+2) x (sizeX+2) grid with a one-cell
        border of ones. Each object writes its ``intensity`` into its
        ``channel``. If ``self.partial`` is set, only the 3x3 window centred
        on the hero is kept.
        """
        a = np.ones([self.sizeY+2,self.sizeX+2,3])
        a[1:-1,1:-1,:] = 0
        hero = None
        for item in self.objects:
            a[item.y+1:item.y+item.size+1,item.x+1:item.x+item.size+1,item.channel] = item.intensity
            if item.name == 'hero':
                hero = item
        if self.partial:
            # Grid coordinates are offset by the border, so this window is
            # centred on the hero's cell.
            a = a[hero.y:hero.y+3,hero.x:hero.x+3,:]
        channels = [self._resizeChannel(a[:,:,c], 84) for c in range(3)]
        return np.stack(channels,axis=2)

    def step(self,action):
        """Apply ``action`` and return ``(state, reward + penalty, done)``."""
        penalty = self.moveChar(action)
        reward,done = self.checkGoal()
        state = self.renderEnv()
        if reward is None:
            # Debug output kept from the original; the addition below then
            # raises TypeError.
            print(done)
            print(reward)
            print(penalty)
        return state,(reward+penalty),done
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import random | |
import tensorflow as tf | |
import matplotlib.pyplot as plt | |
import scipy.misc | |
import os | |
import csv | |
import itertools | |
import tensorflow.contrib.slim as slim | |
def processState(state1):
    """Flatten a game frame (e.g. an 84x84x3 observation) into a 1-D vector.

    The result always has 84 * 84 * 3 == 21168 elements; numpy raises
    ValueError if ``state1`` does not hold exactly that many.
    """
    flat_length = 84 * 84 * 3
    return np.reshape(state1, (flat_length,))
def updateTargetGraph(tfVars,tau):
    """Build ops that move the target network's parameters toward the primary's.

    ``tfVars`` is split in half by position: variable ``k`` of the first
    half is paired with variable ``k`` of the second half. For each pair
    the returned op assigns ``second = first * tau + (1 - tau) * second``.
    If the length is odd, the trailing variable is left untouched.

    Returns:
        list of assign ops, one per pair, in order.
    """
    half = len(tfVars) // 2
    pairs = zip(tfVars[0:half], tfVars[half:])
    return [target.assign((primary.value()*tau) + ((1-tau)*target.value()))
            for primary, target in pairs]
def updateTarget(op_holder,sess):
    """Run the target-update ops, then report whether the target copy matches.

    Runs every op in ``op_holder`` with ``sess``. It then compares the first
    trainable variable with the one half-way through
    ``tf.trainable_variables()``. This assumes the primary/target half-and-half
    layout that ``updateTargetGraph`` uses; verify this for your graph.
    Prints "Target Set Success" when the two are element-wise identical and
    "Target Set Failed" otherwise. With a soft update (tau < 1) the values
    normally differ, so "Failed" is the expected message in that case.
    """
    for op in op_holder:
        sess.run(op)
    trainables = tf.trainable_variables()
    a = trainables[0].eval(session=sess)
    b = trainables[len(trainables)//2].eval(session=sess)
    # The original compared a.all() == b.all(), i.e. two booleans, which is
    # True for almost any pair of arrays. Compare the values themselves.
    if np.array_equal(a, b):
        print("Target Set Success")
    else:
        print("Target Set Failed")
def saveToCenter(i,rList,jList,bufferArray,summaryLength,h_size,sess,mainQN,time_per_step):
    """Record performance metrics and episode logs for the Control Center.

    Writes, for episode ``i``:
      * ``./Center/frames/sal<i>.gif``: per-frame salience maps from ``mainQN``;
      * ``./Center/frames/image<i>.gif``: the episode's observations;
      * one summary row appended to ``./Center/log.csv``;
      * ``./Center/frames/log<i>.csv``: per-step action, reward, the four
        advantage outputs and the value output.

    Args:
        i: episode index, used in the output file names and summary row.
        rList: per-episode rewards; the mean of the last ``summaryLength``
            entries is logged.
        jList: per-episode values (presumably step counts); the mean of the
            last 100 entries is logged.
        bufferArray: 2-D array of the episode's transitions. Column 0 holds
            flattened 21168-element states, column 1 the action and column 2
            the reward. Column 3 of the last row is appended as the final
            frame (presumably the last next-state; verify against the caller).
        summaryLength: window length for the reward mean.
        h_size: size of the zero recurrent state fed to ``mainQN.state_in``.
        sess: session used to evaluate ``mainQN`` tensors.
        mainQN: network exposing salience, rnn_state, scalarInput,
            trainLength, state_in, batch_size, Advantage and Value.
        time_per_step: seconds per frame in the GIFs.
    """
    with open('./Center/log.csv', 'a') as myfile:
        # Replay the episode one frame at a time, carrying the recurrent state
        # forward, to collect one salience map per frame.
        state_display = (np.zeros([1,h_size]),np.zeros([1,h_size]))
        imagesS = []
        for idx, _ in enumerate(np.vstack(bufferArray[:,0])):
            img,state_display = sess.run(
                [mainQN.salience,mainQN.rnn_state],
                feed_dict={mainQN.scalarInput:np.reshape(bufferArray[idx,0],[1,21168])/255.0,
                           mainQN.trainLength:1,
                           mainQN.state_in:state_display,
                           mainQN.batch_size:1})
            imagesS.append(img)
        # Min-max normalise all salience maps jointly.
        imagesS = (imagesS - np.min(imagesS))/(np.max(imagesS) - np.min(imagesS))
        imagesS = np.vstack(imagesS)
        imagesS = np.resize(imagesS,[len(imagesS),84,84,3])
        # The per-pixel max over the last (colour) axis is used as the GIF mask.
        luminance = np.max(imagesS,3)
        make_gif(np.ones([len(imagesS),84,84,3]),'./Center/frames/sal'+str(i)+'.gif',duration=len(imagesS)*time_per_step,true_image=False,salience=True,salIMGS=luminance)

        # zip() returns an iterator in Python 3, which has no append();
        # materialise it so the final frame can be added.
        images = list(zip(bufferArray[:,0]))
        images.append(bufferArray[-1,3])
        images = np.vstack(images)
        images = np.resize(images,[len(images),84,84,3])
        make_gif(images,'./Center/frames/image'+str(i)+'.gif',duration=len(images)*time_per_step,true_image=True,salience=False)

        wr = csv.writer(myfile, quoting=csv.QUOTE_ALL)
        wr.writerow([i,np.mean(jList[-100:]),np.mean(rList[-summaryLength:]),'./frames/image'+str(i)+'.gif','./frames/log'+str(i)+'.csv','./frames/sal'+str(i)+'.gif'])

    with open('./Center/frames/log'+str(i)+'.csv','w') as myfile:
        state_train = (np.zeros([1,h_size]),np.zeros([1,h_size]))
        wr = csv.writer(myfile, quoting=csv.QUOTE_ALL)
        wr.writerow(["ACTION","REWARD","A0","A1",'A2','A3','V'])
        # Evaluate the whole episode in one pass to get per-step outputs.
        a, v = sess.run(
            [mainQN.Advantage,mainQN.Value],
            feed_dict={mainQN.scalarInput:np.vstack(bufferArray[:,0])/255.0,
                       mainQN.trainLength:len(bufferArray),
                       mainQN.state_in:state_train,
                       mainQN.batch_size:1})
        wr.writerows(zip(bufferArray[:,1],bufferArray[:,2],a[:,0],a[:,1],a[:,2],a[:,3],v[:,0]))
def make_gif(images, fname, duration=2, true_image=False,salience=False,salIMGS=None):
    """Save a sequence of frames as a GIF (used for the Control Center).

    Frames are spread evenly over ``duration`` seconds and written at
    ``len(images) / duration`` frames per second with moviepy.

    Args:
        images: indexable sequence of frames.
        fname: output GIF path.
        duration: clip length in seconds.
        true_image: if True, frames are only cast to uint8. Otherwise each
            frame is mapped with (x + 1) / 2 * 255, so -1..1 becomes 0..255.
        salience: if True, write a mask clip built from ``salIMGS`` (set to
            opacity 0.1) instead of the ``images`` clip.
        salIMGS: indexable sequence of per-frame mask images; read only
            when ``salience`` is True.
    """
    import moviepy.editor as mpy

    def make_frame(t):
        try:
            x = images[int(len(images)/duration*t)]
        except IndexError:
            # At t == duration the computed index is one past the last frame.
            x = images[-1]
        if true_image:
            return x.astype(np.uint8)
        else:
            return ((x+1)/2*255).astype(np.uint8)

    def make_mask(t):
        try:
            x = salIMGS[int(len(salIMGS)/duration*t)]
        except IndexError:
            x = salIMGS[-1]
        return x

    clip = mpy.VideoClip(make_frame, duration=duration)
    if salience:
        mask = mpy.VideoClip(make_mask, ismask=True,duration= duration)
        # Only the mask clip is written. The original's unused
        # clip.set_mask(...) / clip.set_opacity(0) copies were dropped.
        mask = mask.set_opacity(0.1)
        mask.write_gif(fname, fps = len(images) / duration,verbose=False)
    else:
        clip.write_gif(fname, fps = len(images) / duration,verbose=False)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment