import numpy as np
import random
import itertools
import scipy.misc
import matplotlib.pyplot as plt
class gameOb():
    """A single object on the grid: position, size, color channel, reward, and name."""
    def __init__(self, coordinates, size, intensity, channel, reward, name):
        self.x = coordinates[0]
        self.y = coordinates[1]
        self.size = size
        self.intensity = intensity
        self.channel = channel
        self.reward = reward
        self.name = name
class gameEnv():
    """Gridworld with one hero, four goals (+1 reward), and two fire squares (-1 reward)."""
    def __init__(self, partial, size):
        self.sizeX = size
        self.sizeY = size
        self.actions = 4
        self.objects = []
        self.partial = partial
        a = self.reset()
        plt.imshow(a, interpolation="nearest")

    def reset(self):
        """Repopulate the grid with objects at random positions and return the initial state."""
        self.objects = []
        hero = gameOb(self.newPosition(), 1, 1, 2, None, 'hero')
        self.objects.append(hero)
        bug = gameOb(self.newPosition(), 1, 1, 1, 1, 'goal')
        self.objects.append(bug)
        hole = gameOb(self.newPosition(), 1, 1, 0, -1, 'fire')
        self.objects.append(hole)
        bug2 = gameOb(self.newPosition(), 1, 1, 1, 1, 'goal')
        self.objects.append(bug2)
        hole2 = gameOb(self.newPosition(), 1, 1, 0, -1, 'fire')
        self.objects.append(hole2)
        bug3 = gameOb(self.newPosition(), 1, 1, 1, 1, 'goal')
        self.objects.append(bug3)
        bug4 = gameOb(self.newPosition(), 1, 1, 1, 1, 'goal')
        self.objects.append(bug4)
        state = self.renderEnv()
        self.state = state
        return state
    def moveChar(self, direction):
        # 0 - up, 1 - down, 2 - left, 3 - right
        hero = self.objects[0]
        heroX = hero.x
        heroY = hero.y
        penalize = 0.
        if direction == 0 and hero.y >= 1:
            hero.y -= 1
        if direction == 1 and hero.y <= self.sizeY - 2:
            hero.y += 1
        if direction == 2 and hero.x >= 1:
            hero.x -= 1
        if direction == 3 and hero.x <= self.sizeX - 2:
            hero.x += 1
        if hero.x == heroX and hero.y == heroY:
            # Bumped into a wall: no movement penalty in this version.
            penalize = 0.0
        self.objects[0] = hero
        return penalize
    def newPosition(self):
        """Return a random unoccupied (x, y) cell."""
        iterables = [range(self.sizeX), range(self.sizeY)]
        points = []
        for t in itertools.product(*iterables):
            points.append(t)
        currentPositions = []
        for objectA in self.objects:
            if (objectA.x, objectA.y) not in currentPositions:
                currentPositions.append((objectA.x, objectA.y))
        for pos in currentPositions:
            points.remove(pos)
        location = np.random.choice(range(len(points)), replace=False)
        return points[location]
    def checkGoal(self):
        """If the hero shares a cell with an object, consume it and return its reward."""
        others = []
        for obj in self.objects:
            if obj.name == 'hero':
                hero = obj
            else:
                others.append(obj)
        ended = False
        for other in others:
            if hero.x == other.x and hero.y == other.y:
                self.objects.remove(other)
                # Respawn a like-for-like object so the object counts stay constant.
                if other.reward == 1:
                    self.objects.append(gameOb(self.newPosition(), 1, 1, 1, 1, 'goal'))
                else:
                    self.objects.append(gameOb(self.newPosition(), 1, 1, 0, -1, 'fire'))
                return other.reward, False
        if ended == False:
            return 0.0, False
    def renderEnv(self):
        """Render the grid to an 84x84x3 image (a 3x3 window around the hero if partial)."""
        a = np.ones([self.sizeY + 2, self.sizeX + 2, 3])
        a[1:-1, 1:-1, :] = 0
        hero = None
        for item in self.objects:
            a[item.y + 1:item.y + item.size + 1, item.x + 1:item.x + item.size + 1, item.channel] = item.intensity
            if item.name == 'hero':
                hero = item
        if self.partial == True:
            a = a[hero.y:hero.y + 3, hero.x:hero.x + 3, :]
        # Note: scipy.misc.imresize was removed in SciPy 1.2; this code assumes SciPy < 1.2.
        b = scipy.misc.imresize(a[:, :, 0], [84, 84, 1], interp='nearest')
        c = scipy.misc.imresize(a[:, :, 1], [84, 84, 1], interp='nearest')
        d = scipy.misc.imresize(a[:, :, 2], [84, 84, 1], interp='nearest')
        a = np.stack([b, c, d], axis=2)
        return a
    def step(self, action):
        penalty = self.moveChar(action)
        reward, done = self.checkGoal()
        state = self.renderEnv()
        if reward is None:
            # Diagnostic trace for an unexpected missing reward.
            print(done)
            print(reward)
            print(penalty)
        return state, (reward + penalty), done
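# Example usage (a minimal sketch, not part of the original gist): build a
# 9x9 partially observable environment and take a few random steps.
def _demo_gameEnv():
    env = gameEnv(partial=True, size=9)
    for _ in range(5):
        action = np.random.randint(0, env.actions)
        state, reward, done = env.step(action)
        print(state.shape, reward, done)  # state is an 84x84x3 frame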
# ---- Helper utilities (the second file in this gist) ----
import numpy as np
import random
import tensorflow as tf
import matplotlib.pyplot as plt
import scipy.misc
import os
import csv
import itertools
import tensorflow.contrib.slim as slim  # tf.contrib.slim exists in TensorFlow 1.x only
# This is a simple function to reshape our game frames.
def processState(state1):
    return np.reshape(state1, [21168])  # 21168 = 84 * 84 * 3
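# Example (a sketch, not part of the original gist; assumes the gameEnv class
# from the gridworld code above is in scope): flatten one rendered frame and
# restore its 84x84x3 shape.
def _demo_processState():
    env = gameEnv(partial=False, size=5)
    flat = processState(env.renderEnv())   # shape (21168,)
    restored = np.reshape(flat, [84, 84, 3])
    return restored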
# These functions allow us to update the parameters of our target network with those of the primary network.
def updateTargetGraph(tfVars, tau):
    # The first half of tfVars belongs to the primary network, the second half
    # to the target network; tau controls how far the target moves per update.
    total_vars = len(tfVars)
    op_holder = []
    for idx, var in enumerate(tfVars[0:total_vars // 2]):
        op_holder.append(tfVars[idx + total_vars // 2].assign(
            (var.value() * tau) + ((1 - tau) * tfVars[idx + total_vars // 2].value())))
    return op_holder
def updateTarget(op_holder, sess):
    for op in op_holder:
        sess.run(op)
    total_vars = len(tf.trainable_variables())
    a = tf.trainable_variables()[0].eval(session=sess)
    b = tf.trainable_variables()[total_vars // 2].eval(session=sess)
    # Loose sanity check: .all() reduces each array to a single boolean, so this
    # only compares whether both arrays are all-nonzero, not elementwise equality.
    if a.all() == b.all():
        print("Target Set Success")
    else:
        print("Target Set Failed")
# Record performance metrics and episode logs for the Control Center.
# bufferArray rows appear to be experience tuples of the form
# [state, action, reward, next_state, ...].
def saveToCenter(i, rList, jList, bufferArray, summaryLength, h_size, sess, mainQN, time_per_step):
    with open('./Center/log.csv', 'a') as myfile:
        state_display = (np.zeros([1, h_size]), np.zeros([1, h_size]))
        imagesS = []
        # Run each stored frame through the network to collect its salience map.
        for idx, z in enumerate(np.vstack(bufferArray[:, 0])):
            img, state_display = sess.run([mainQN.salience, mainQN.rnn_state],
                feed_dict={mainQN.scalarInput: np.reshape(bufferArray[idx, 0], [1, 21168]) / 255.0,
                           mainQN.trainLength: 1, mainQN.state_in: state_display, mainQN.batch_size: 1})
            imagesS.append(img)
        # Normalize the salience maps to [0, 1].
        imagesS = (imagesS - np.min(imagesS)) / (np.max(imagesS) - np.min(imagesS))
        imagesS = np.vstack(imagesS)
        imagesS = np.resize(imagesS, [len(imagesS), 84, 84, 3])
        luminance = np.max(imagesS, 3)
        imagesS = np.multiply(np.ones([len(imagesS), 84, 84, 3]), np.reshape(luminance, [len(imagesS), 84, 84, 1]))
        make_gif(np.ones([len(imagesS), 84, 84, 3]), './Center/frames/sal' + str(i) + '.gif',
                 duration=len(imagesS) * time_per_step, true_image=False, salience=True, salIMGS=luminance)
        # list(...) is needed here: in Python 3, zip returns an iterator with no append.
        images = list(zip(bufferArray[:, 0]))
        images.append(bufferArray[-1, 3])
        images = np.vstack(images)
        images = np.resize(images, [len(images), 84, 84, 3])
        make_gif(images, './Center/frames/image' + str(i) + '.gif',
                 duration=len(images) * time_per_step, true_image=True, salience=False)
        wr = csv.writer(myfile, quoting=csv.QUOTE_ALL)
        wr.writerow([i, np.mean(jList[-100:]), np.mean(rList[-summaryLength:]),
                     './frames/image' + str(i) + '.gif', './frames/log' + str(i) + '.csv',
                     './frames/sal' + str(i) + '.gif'])
    with open('./Center/frames/log' + str(i) + '.csv', 'w') as myfile:
        state_train = (np.zeros([1, h_size]), np.zeros([1, h_size]))
        wr = csv.writer(myfile, quoting=csv.QUOTE_ALL)
        wr.writerow(["ACTION", "REWARD", "A0", "A1", "A2", "A3", "V"])
        a, v = sess.run([mainQN.Advantage, mainQN.Value],
            feed_dict={mainQN.scalarInput: np.vstack(bufferArray[:, 0]) / 255.0,
                       mainQN.trainLength: len(bufferArray), mainQN.state_in: state_train,
                       mainQN.batch_size: 1})
        wr.writerows(zip(bufferArray[:, 1], bufferArray[:, 2], a[:, 0], a[:, 1], a[:, 2], a[:, 3], v[:, 0]))
# This code allows gifs to be saved of the training episode for use in the Control Center.
def make_gif(images, fname, duration=2, true_image=False, salience=False, salIMGS=None):
    import moviepy.editor as mpy

    def make_frame(t):
        try:
            x = images[int(len(images) / duration * t)]
        except IndexError:
            x = images[-1]
        if true_image:
            return x.astype(np.uint8)
        else:
            return ((x + 1) / 2 * 255).astype(np.uint8)

    def make_mask(t):
        try:
            x = salIMGS[int(len(salIMGS) / duration * t)]
        except IndexError:
            x = salIMGS[-1]
        return x

    clip = mpy.VideoClip(make_frame, duration=duration)
    if salience == True:
        mask = mpy.VideoClip(make_mask, ismask=True, duration=duration)
        clipB = clip.set_mask(mask)
        clipB = clip.set_opacity(0)  # note: this overwrites the masked clip above
        mask = mask.set_opacity(0.1)
        mask.write_gif(fname, fps=len(images) / duration, verbose=False)
        #clipB.write_gif(fname, fps=len(images) / duration, verbose=False)
    else:
        clip.write_gif(fname, fps=len(images) / duration, verbose=False)
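# Example usage (a minimal sketch, not part of the original gist; assumes an
# older moviepy whose write_gif still accepts verbose): write a short
# random-noise gif to an arbitrary path.
def _demo_make_gif():
    frames = np.random.randint(0, 255, size=[8, 84, 84, 3])
    make_gif(frames, './demo.gif', duration=1, true_image=True)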