Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
import numpy as np | |
import random | |
import itertools | |
import scipy.misc | |
import matplotlib.pyplot as plt | |
class gameOb():
    """One square object on the grid: the hero, a goal, or a fire.

    Args:
        coordinates: indexable pair; element 0 becomes ``x`` (column) and
            element 1 becomes ``y`` (row).  Extra elements are ignored.
        size: side length of the square, in cells.
        intensity: value written into ``channel`` when the board is rendered.
        channel: index of the colour channel the object is drawn in.
        reward: reward for stepping onto the object (``None`` for the hero).
        name: object kind, e.g. ``'hero'``, ``'goal'`` or ``'fire'``.
    """

    def __init__(self,coordinates,size,intensity,channel,reward,name):
        self.x, self.y = coordinates[0], coordinates[1]
        self.size, self.intensity, self.channel = size, intensity, channel
        self.reward, self.name = reward, name
class gameEnv():
    """Grid world with one hero, goal squares (reward +1) and fire squares (reward -1).

    The board is ``size`` x ``size`` cells.  Observations are
    ``frameSize`` x ``frameSize`` x 3 ``uint8`` images.  Each one shows the
    board surrounded by a one-cell border of ones, with each object drawn
    into its own colour channel.  When ``partial`` is true, only the 3x3
    window centred on the hero is shown.  The hero has ``self.actions`` (4)
    moves.
    """

    # Side length, in pixels, of rendered observations.  Code that flattens
    # frames into 21168 values assumes 84 (84 * 84 * 3 == 21168).
    frameSize = 84

    def __init__(self,partial,size):
        self.sizeX = size
        self.sizeY = size
        self.actions = 4
        self.objects = []
        self.partial = partial
        a = self.reset()
        plt.imshow(a,interpolation="nearest")

    def reset(self):
        """Place a fresh set of objects on free cells and return the first frame.

        Objects are placed one at a time, in this order: hero, goal, fire,
        goal, fire, goal, goal.  ``self.objects[0]`` is therefore always the
        hero, which ``moveChar`` relies on.
        """
        self.objects = []
        # (channel, reward, name) for each object, in placement order.
        layout = [
            (2, None, 'hero'),
            (1, 1, 'goal'),
            (0, -1, 'fire'),
            (1, 1, 'goal'),
            (0, -1, 'fire'),
            (1, 1, 'goal'),
            (1, 1, 'goal'),
        ]
        for channel, reward, name in layout:
            # newPosition sees every object appended so far, so no two overlap.
            self.objects.append(gameOb(self.newPosition(), 1, 1, channel, reward, name))
        self.state = self.renderEnv()
        return self.state

    def moveChar(self,direction):
        """Move the hero one cell: 0 = up, 1 = down, 2 = left, 3 = right.

        A move that would leave the board, or an unknown direction, leaves
        the hero in place.  Returns the movement penalty, which is currently
        always 0.0 (bumping into a wall is not penalised).
        """
        hero = self.objects[0]
        if direction == 0 and hero.y >= 1:
            hero.y -= 1
        elif direction == 1 and hero.y <= self.sizeY-2:
            hero.y += 1
        elif direction == 2 and hero.x >= 1:
            hero.x -= 1
        elif direction == 3 and hero.x <= self.sizeX-2:
            hero.x += 1
        return 0.0

    def newPosition(self):
        """Return a uniformly random unoccupied ``(x, y)`` cell.

        Raises:
            ValueError: if every cell on the board is already occupied.
        """
        occupied = {(obj.x, obj.y) for obj in self.objects}
        points = [p for p in itertools.product(range(self.sizeX), range(self.sizeY))
                  if p not in occupied]
        if not points:
            raise ValueError("no free cell left on the board")
        # Same sampling call as before, so np.random.seed reproduces layouts.
        location = np.random.choice(range(len(points)),replace=False)
        return points[location]

    def checkGoal(self):
        """Resolve a collision between the hero and another object.

        If the hero shares a cell with a goal or fire, that object is
        removed, a replacement of the same kind is spawned on a free cell,
        and ``(reward, False)`` is returned.  Only the first collision found
        is handled.  Otherwise ``(0.0, False)`` is returned.  The episode is
        never ended here, so the second element is always False.
        """
        hero = None
        others = []
        for obj in self.objects:
            if obj.name == 'hero':
                hero = obj
            else:
                others.append(obj)
        for other in others:
            if (other.x, other.y) == (hero.x, hero.y):
                self.objects.remove(other)
                if other.reward == 1:
                    self.objects.append(gameOb(self.newPosition(),1,1,1,1,'goal'))
                else:
                    self.objects.append(gameOb(self.newPosition(),1,1,0,-1,'fire'))
                return other.reward,False
        return 0.0,False

    def renderEnv(self):
        """Render the board as a ``frameSize`` x ``frameSize`` x 3 uint8 image.

        The board is padded with a one-cell border of ones.  Each object
        writes its ``intensity`` into its ``channel``.  In partial mode only
        the 3x3 window centred on the hero is kept.  The grid is then
        upscaled with nearest-neighbour sampling.
        """
        a = np.ones([self.sizeY+2,self.sizeX+2,3])
        a[1:-1,1:-1,:] = 0
        hero = None
        for item in self.objects:
            a[item.y+1:item.y+item.size+1,item.x+1:item.x+item.size+1,item.channel] = item.intensity
            if item.name == 'hero':
                hero = item
        if self.partial:
            # The +1 border offset makes this slice the 3x3 window around the hero.
            a = a[hero.y:hero.y+3,hero.x:hero.x+3,:]
        return self._upscale(a, self.frameSize)

    @staticmethod
    def _upscale(grid, frameSize=84):
        """Nearest-neighbour upscale an (h, w, 3) grid with values in [0, 1] to uint8.

        This replaces ``scipy.misc.imresize``, which was removed in SciPy 1.3.
        Each output pixel samples the source cell under its centre, matching
        PIL's NEAREST filter for whole-number scale factors.  Values are
        mapped linearly from [0, 1] to [0, 255].  For the 0/1 intensities
        used by this environment, that equals imresize's per-channel
        min/max stretch.
        """
        h, w = grid.shape[:2]
        rows = ((np.arange(frameSize) + 0.5) * h / frameSize).astype(int)
        cols = ((np.arange(frameSize) + 0.5) * w / frameSize).astype(int)
        pixels = grid[rows[:, None], cols[None, :], :]
        return np.rint(np.clip(pixels, 0, 1) * 255).astype(np.uint8)

    def step(self,action):
        """Apply ``action`` and return ``(state, reward, done)``.

        ``reward`` is the collision reward from ``checkGoal`` plus the
        (currently always zero) movement penalty; ``done`` is always False.
        """
        penalty = self.moveChar(action)
        reward,done = self.checkGoal()
        state = self.renderEnv()
        return state,(reward+penalty),done
import numpy as np | |
import random | |
import tensorflow as tf | |
import matplotlib.pyplot as plt | |
import scipy.misc | |
import os | |
import csv | |
import itertools | |
import tensorflow.contrib.slim as slim | |
# This is a simple function to flatten our game frames into 1-D vectors.
def processState(state1):
    """Flatten one game frame into a 1-D vector of 21168 (= 84 * 84 * 3) values.

    Raises:
        ValueError: (from numpy) if ``state1`` does not hold exactly 21168 elements.
    """
    frame = np.asanyarray(state1)
    return frame.reshape(84 * 84 * 3)
# These functions allow us to update the parameters of our target network with those of the primary network.
def updateTargetGraph(tfVars,tau):
    """Build ops that move each target-network variable towards its primary twin.

    ``tfVars`` is split in half.  Element ``k`` of the first half is paired
    with element ``k`` of the second half.  Presumably the first half holds
    the primary network and the second half the target network, in matching
    order; confirm against the caller.  For each pair an assign op is built
    that sets ``target = tau * primary + (1 - tau) * target``.

    Returns:
        List of assign ops, one per pair; they take effect only when run.
    """
    half = len(tfVars)//2
    primaries = tfVars[:half]
    targets = tfVars[half:]
    return [target.assign((primary.value()*tau) + ((1-tau)*target.value()))
            for primary, target in zip(primaries, targets)]
def updateTarget(op_holder,sess):
    """Run the target-update ops, then print whether the first pair now matches.

    Args:
        op_holder: assign ops, as returned by ``updateTargetGraph``.
        sess: TensorFlow session used to run the ops and read variables back.

    After the ops run, ``tf.trainable_variables()[0]`` is compared with the
    variable half-way through that list.  These are the pair that
    ``updateTargetGraph`` links.  The function prints "Target Set Success"
    if their values are element-wise close, and "Target Set Failed"
    otherwise.
    """
    for op in op_holder:
        sess.run(op)
    tfVars = tf.trainable_variables()
    primary = tfVars[0].eval(session=sess)
    target = tfVars[len(tfVars)//2].eval(session=sess)
    # Compare values element-wise.  The previous `a.all() == b.all()` only
    # compared two "all elements non-zero" booleans, so it reported success
    # for almost any pair.  With a soft update (tau < 1) the target only
    # moves part-way, so "Failed" is the accurate report in that case.
    if np.allclose(primary, target):
        print("Target Set Success")
    else:
        print("Target Set Failed")
#Record performance metrics and episode logs for the Control Center. | |
def saveToCenter(i,rList,jList,bufferArray,summaryLength,h_size,sess,mainQN,time_per_step):
    """Log episode ``i`` for the Control Center: a summary row, two GIFs and a per-step CSV.

    Args:
        i: episode index, used in the output file names.
        rList: per-episode rewards; the mean of the last ``summaryLength`` is logged.
        jList: per-episode step counts; the mean of the last 100 is logged.
        bufferArray: 2-D episode buffer indexed ``[step, column]``.  Column 0
            is read as the state, 1 as the action and 2 as the reward.
            Column 3 of the last row is added as the final frame; presumably
            it is the next state (TODO confirm against the caller).
        summaryLength: window length for the reward mean.
        h_size: recurrent state size; zero states of shape (1, h_size) seed the RNN.
        sess: TensorFlow session.
        mainQN: network exposing ``salience``, ``rnn_state``, ``Advantage``,
            ``Value`` and the placeholders fed below.
        time_per_step: seconds each frame is shown in the GIFs.

    Files written: ./Center/log.csv (appended), ./Center/frames/sal{i}.gif,
    ./Center/frames/image{i}.gif and ./Center/frames/log{i}.csv.
    """
    with open('./Center/log.csv', 'a') as myfile:
        # Salience maps: feed the episode one frame at a time, carrying the
        # recurrent state forward from frame to frame.
        state_display = (np.zeros([1,h_size]),np.zeros([1,h_size]))
        imagesS = []
        for idx in range(len(bufferArray)):
            img,state_display = sess.run([mainQN.salience,mainQN.rnn_state],
                feed_dict={mainQN.scalarInput:np.reshape(bufferArray[idx,0],[1,21168])/255.0,
                           mainQN.trainLength:1,mainQN.state_in:state_display,mainQN.batch_size:1})
            imagesS.append(img)
        # Normalise salience to [0, 1]; guard against a constant map (0 / 0).
        imagesS = np.asarray(imagesS)
        lo, hi = np.min(imagesS), np.max(imagesS)
        if hi > lo:
            imagesS = (imagesS - lo)/(hi - lo)
        else:
            imagesS = np.zeros(imagesS.shape)
        imagesS = np.vstack(imagesS)
        imagesS = np.resize(imagesS,[len(imagesS),84,84,3])
        luminance = np.max(imagesS,3)
        make_gif(np.ones([len(imagesS),84,84,3]),'./Center/frames/sal'+str(i)+'.gif',duration=len(imagesS)*time_per_step,true_image=False,salience=True,salIMGS=luminance)

        # Episode frames: every state, plus column 3 of the final step.
        # list() rather than zip(): zip is an iterator without .append in Python 3.
        images = list(bufferArray[:,0])
        images.append(bufferArray[-1,3])
        images = np.vstack(images)
        images = np.resize(images,[len(images),84,84,3])
        make_gif(images,'./Center/frames/image'+str(i)+'.gif',duration=len(images)*time_per_step,true_image=True,salience=False)

        wr = csv.writer(myfile, quoting=csv.QUOTE_ALL)
        wr.writerow([i,np.mean(jList[-100:]),np.mean(rList[-summaryLength:]),'./frames/image'+str(i)+'.gif','./frames/log'+str(i)+'.csv','./frames/sal'+str(i)+'.gif'])
    with open('./Center/frames/log'+str(i)+'.csv','w') as myfile:
        # Per-step advantages and value for the whole episode, run as one sequence.
        state_train = (np.zeros([1,h_size]),np.zeros([1,h_size]))
        wr = csv.writer(myfile, quoting=csv.QUOTE_ALL)
        wr.writerow(["ACTION","REWARD","A0","A1",'A2','A3','V'])
        a, v = sess.run([mainQN.Advantage,mainQN.Value],
            feed_dict={mainQN.scalarInput:np.vstack(bufferArray[:,0])/255.0,mainQN.trainLength:len(bufferArray),mainQN.state_in:state_train,mainQN.batch_size:1})
        wr.writerows(zip(bufferArray[:,1],bufferArray[:,2],a[:,0],a[:,1],a[:,2],a[:,3],v[:,0]))
# This code saves GIFs of a training episode for use in the Control Center.
def make_gif(images, fname, duration=2, true_image=False,salience=False,salIMGS=None):
    """Write a GIF of ``duration`` seconds to ``fname``.

    Args:
        images: sequence of frames.  Frame k is shown from
            ``k * duration / len(images)`` seconds onward.
        fname: output path.
        duration: clip length in seconds.
        true_image: if True, frames are cast to uint8 as-is.  Otherwise they
            are mapped with ``(x + 1) / 2 * 255``, i.e. presumably from
            [-1, 1] to [0, 255].
        salience: if True, the GIF is written from a mask clip built from
            ``salIMGS`` rather than from ``images``.  ``images`` then only
            sets the frame rate.
        salIMGS: per-frame mask values; required when ``salience`` is True.
    """
    import moviepy.editor as mpy

    def _at(frames, t):
        # Frame shown at time t.  Clamp so t == duration (which would index
        # one past the end) maps to the last frame.
        return frames[min(int(len(frames)/duration*t), len(frames)-1)]

    def make_frame(t):
        x = _at(images, t)
        if true_image:
            return x.astype(np.uint8)
        else:
            return ((x+1)/2*255).astype(np.uint8)

    def make_mask(t):
        return _at(salIMGS, t)

    clip = mpy.VideoClip(make_frame, duration=duration)
    if salience:
        mask = mpy.VideoClip(make_mask, ismask=True,duration= duration)
        # NOTE(review): clipB is built but never written (only the mask GIF
        # is saved).  These calls are kept unchanged pending confirmation
        # that dropping them does not alter moviepy behaviour.
        clipB = clip.set_mask(mask)
        clipB = clip.set_opacity(0)
        mask = mask.set_opacity(0.1)
        mask.write_gif(fname, fps = len(images) / duration,verbose=False)
    else:
        clip.write_gif(fname, fps = len(images) / duration,verbose=False)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment