Created
October 8, 2017 11:01
-
-
Save Shashi18/caf00ff77463904cd9e59e8058df47c9 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sys | |
from pygame.locals import * | |
import numpy as np | |
import pygame as pg | |
import random | |
import time | |
start = time.time() | |
FPS = 150 # frames per second setting | |
fpsClock = pg.time.Clock() | |
pg.init() # pygame initialization | |
window = pg.display.set_mode((800, 600)) # width, height | |
pg.display.set_caption('Q learning Example!') | |
Qa = 0 | |
Left = 400 | |
Top = 570 | |
Width = 100 | |
Height = 20 | |
i = 0 | |
BLACK = (0,0,0) | |
WHITE = (255,255,255) | |
GREEN = (0,255,0) | |
rct = pg.Rect(Left, Top, Width, Height) # Rect(left, top, width, height) | |
action = 2 # 0 means stay, 1 means left, 2 means right | |
storage = {} #Dictionary | |
jumpY = 6 | |
jumpX = 8 | |
Q = np.zeros([25000, 3]) | |
cenX = 10 | |
cenY = 50 | |
radius = 10 | |
score = 0 | |
missed = 0 | |
reward = 0 | |
font = pg.font.Font(None, 30) | |
# set learning rate | |
lr = 0.7 | |
y = .5 | |
i = 0 | |
def calculate_score(rect, circle): | |
if rect.left <= circle.circleX <= rect.right: # if the circle'x x position is between the rectangles left and right | |
return 1 | |
else: | |
return -1 | |
def newXforCircle(radius): | |
newx = 100 - radius | |
multiplier = float(random.randint(1, 8)) # make more channel by making it a floating point number | |
newx *= multiplier | |
return newx | |
class State: | |
def __init__(self, rect, circle): | |
self.rect = rect | |
self.circle = circle | |
class Circle: | |
def __init__(self, circleX, circleY): | |
self.circleX = circleX | |
self.circleY = circleY | |
def convert(s): | |
y = int(s.circle.circleY) | |
x = int(s.circle.circleX) | |
z = int(s.rect.left) | |
n = float(str(x)+str(z)) | |
#print(str(x)+' '+str(y)+' '+str(x)+str(y)+str(z)+' '+str(n)) | |
if n in storage: | |
#print ('R '+str(n)) | |
return storage[n] | |
else: | |
if len(storage): | |
maximum = max(storage, key=storage.get) | |
storage[n] = storage[maximum] + 1 | |
else: | |
storage[n] = 1 | |
return storage[n] | |
def action(s): | |
return np.argmax(Q[convert(s), :]) | |
def afteraction(s, act): | |
rct = None | |
if act == 2: | |
if s.rect.right + 100 > 800: | |
rct = s.rect | |
else: | |
rct = pg.Rect(s.rect.left + 100, s.rect.top, s.rect.width, | |
s.rect.height) | |
elif act == 1: | |
if s.rect.left - 100 < 0: | |
rct = s.rect | |
else: | |
rct = pg.Rect(s.rect.left - 100, s.rect.top, s.rect.width, | |
s.rect.height) # Rect(left, top, width, height) | |
else: # action is 0, means stay where it is | |
rct = s.rect | |
X = s.circle.circleX + jumpX | |
Y = s.circle.circleY + jumpY | |
print (str(X)+' '+str(Y)) | |
newCircle = Circle(X, Y) | |
return State(rct, newCircle) | |
def newRect(rect, act): | |
if act == 2: | |
if rect.right + 100 > 800: | |
return rect | |
else: | |
return pg.Rect(rect.left + 100, rect.top, rect.width, rect.height) | |
elif act == 1: # action is left | |
if rect.left - 100 < 0: | |
return rect | |
else: | |
return pg.Rect(rect.left - 100, rect.top, rect.width, rect.height) | |
else: | |
return rect | |
while True: | |
for event in pg.event.get(): | |
if event.type == QUIT: | |
np.savetxt('test.txt', Q) | |
pg.quit() | |
sys.exit() | |
COL = [(255,255,0),(255,215,0),(238,221,130),(218,165,32),(184,134,11),(208,32,144),(238,130,238),(221,160,221),(218,112,214),(186,85,211),(153,50,204), | |
(148,0,211),(138,43,226),(173,255,47),(50,205,50),(154,205,50),(34,139,34),(107,142,35),(189,183,107),(240,230,140)] | |
window.fill((0,45,45)) | |
if cenY >= 590 - Height - radius: | |
reward = calculate_score(rct, Circle(cenX, cenY)) | |
if reward == -1: | |
cenX = newXforCircle(radius) | |
cenY = 50 | |
else: | |
Qa = COL[random.randint(0,19)] | |
jumpY *= -1 | |
cenY += jumpY | |
elif cenY < 50 and i!=0: | |
cenY += jumpY | |
jumpY = abs(jumpY) | |
else: | |
cenY+=jumpY | |
if cenX >= (800 - radius): | |
jumpX *= -1 | |
cenX += jumpX | |
elif cenX <= 2*radius and i!=0: | |
cenX += jumpX | |
jumpX = abs(jumpX) | |
else: | |
cenX += jumpX | |
print('b') | |
print (str(cenX)) | |
print('X: '+str(jumpX)+' Y: '+str(jumpY)) | |
s = State(rct, Circle(cenX, cenY)) | |
act = action(s) | |
r0 = calculate_score(s.rect, s.circle) | |
s1 = afteraction(s, act) | |
actx = action(s1) | |
Q[convert(s), act] += lr*(r0 + y * np.max(Q[convert(s1), :]) - Q[convert(s), act]) | |
rct = newRect(s.rect, act) | |
pg.draw.circle(window, Qa, (int(cenX),int(cenY)),radius) | |
pg.draw.rect(window, WHITE, rct) | |
if reward == 1: | |
score += reward | |
else: | |
missed += reward | |
reward = 0 | |
LR = '%.2f' % (abs(score+reward)/(1+abs(missed)+score)*100) | |
text = font.render('Score: ' + str(score), True, (243, 160, 90)) | |
text1 = font.render('Penalty: ' + str(missed), True, (125, 157, 207)) | |
text2 = font.render('LR :' + str(LR), True, (0, 255, 20)) | |
window.blit(text, (670, 10)) # render score | |
window.blit(text1, (10, 10)) # render missed | |
window.blit(text2, (320, 10)) | |
pg.display.update() # update display | |
fpsClock.tick(FPS) | |
i = 1 |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment