Created
March 14, 2019 15:31
-
-
Save daviwesley/d99ba56b8685faeb3b5c07b7523ce8ab to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class QLearning(): | |
def __init__(self): | |
self.self.matrizQ = [[0]*5 for x in range(5)] | |
self.learningRate = 0.9 | |
self.discountRate = 0 | |
self.indiceAcaoAtual = 2 | |
self.indiceEstadoAtual = 0 | |
def calculaFatores(): | |
self.learningRate -= 0.1 | |
self.discountRate += 0.1 | |
def calculaMelhorAcao(): | |
vetor = self.matrizQ[self.indiceEstadoAtual] | |
valorDaAcaoAtual = vetor[0] | |
indiceDaAcao = 0 | |
indiceAcaoEscolhida = 0 | |
for acao in vetor : | |
if acao > valorDaAcaoAtual : | |
indiceAcaoEscolhida = indiceDaAcao | |
valorDaAcaoAtual = acao | |
indiceDaAcao += 1 | |
#atualizando variável global de ação | |
self.indiceAcaoAtual = indiceAcaoEscolhida | |
#atualizando variável global de estado | |
def politicaDeAcoes(): | |
self.indiceEstadoAtual += 1 | |
if self.indiceEstadoAtual >= len(self.matrizQ): | |
self.indiceEstadoAtual = 0 | |
calculaMelhorAcao() | |
calculaFatores() | |
def maxQProximoEstado(proximoEstado): | |
proximoEstado += 1 | |
if proximoEstado > len(self.matrizQ) : | |
proximoEstado = 0 | |
vetor = self.matrizQ[proximoEstado] | |
maiorRecompensa = max(vetor[0]) | |
return maiorRecompensa | |
def calculaNovoValorQ(reward): | |
print(self.matrizQ) | |
qAtual = self.matrizQ[self.indiceEstadoAtual][self.indiceAcaoAtual] | |
novoValorQ = qAtual + self.learningRate*(reward + self.discountRate * maxQProximoEstado(self.indiceEstadoAtual) - qAtual) | |
self.matrizQ = novoValorQ | |
politicaDeAcoes() | |
def exibeMensagemDeAcaoAtual(): | |
if self.indiceAcaoAtual == 0 : | |
print("Acao muito fácil") | |
elif self.indiceAcaoAtual == 1: | |
print("Acao fácil") | |
elif self.indiceAcaoAtual == 2: | |
print("Acao intermediária") | |
elif self.indiceAcaoAtual == 3: | |
print("Acao difícil") | |
elif self.indiceAcaoAtual == 4: | |
print("Acao muito difícil") | |
else: | |
print("valor inválido") |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment