Created
March 23, 2016 10:44
-
-
Save clizarralde/81540cb7f7ad587e88d8 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
''' | |
Created on 18/9/2015 | |
@author: Charly | |
''' | |
from promiedos import PartidoModel, Session, PosicionModel | |
from sqlalchemy.sql.expression import desc, or_, and_ | |
from sqlalchemy.sql.functions import func | |
import logging | |
from math import sqrt | |
from sklearn.ensemble.forest import RandomForestClassifier | |
from sklearn.cross_validation import cross_val_score | |
from sklearn import svm, linear_model, cross_validation | |
import urllib2 | |
from scrapy.http.response.html import HtmlResponse | |
from sklearn.linear_model.stochastic_gradient import SGDClassifier | |
from collections import defaultdict | |
import operator | |
import random | |
import csv | |
GRANDES = ['Independiente' , 'Racing Club', 'Boca Juniors', 'River Plate', 'San Lorenzo'] | |
WEIGHTS = [0.25, 0.25, 0.3, 0.2] | |
class Predictor: | |
def __init__(self,weights=WEIGHTS): | |
self.weights = weights | |
self.daysOfHistory = 8 | |
self.session = Session() | |
self.logger = logging.getLogger() | |
def getGoles(self, equipo, local, cant_partidos ,anio=None, fecha=None): | |
''' | |
Saca la cantidad media de goles de un equipo para una determinada fecha o partido. | |
''' | |
matches = self.getUltimosPartidos(equipo, local, cant_partidos, anio, fecha) | |
avg = 0.0 | |
for m in matches: | |
if local: | |
avg+=m.gol1 | |
else: | |
avg+=m.gol2 | |
avg = avg/len(matches) | |
return avg | |
def getUltimosPartidos(self, equipo, local, cant_partidos ,anio=None, fecha=None): | |
''' | |
Devuelve la lista de los ultimos N partidos de un equipo como local o visitante. | |
''' | |
if equipo == 'Gimnasia LP': | |
equipo = 'Gimnasia (LP)' | |
if equipo == 'Estudiantes LP': | |
equipo = 'Estudiantes (LP)' | |
query = self.session.query(PartidoModel) | |
if local: | |
query = query.filter(PartidoModel.equipo1==equipo) | |
else: | |
query = query.filter(PartidoModel.equipo2==equipo) | |
if anio: | |
query = query.filter(PartidoModel.anio==anio) | |
if fecha: | |
query = query.filter(PartidoModel.nfecha<fecha) | |
matches = query.order_by(desc(PartidoModel.anio),desc(PartidoModel.nfecha)).limit(cant_partidos).all() | |
return matches | |
def getPromedio(self, equipo, local, cant_partidos ,anio=None, fecha=None): | |
''' | |
Devuelve el promedio de la suma de los goles de los ultimos partidos antes de la fecha y anio | |
''' | |
matches = self.getUltimosPartidos(equipo, local, cant_partidos, anio, fecha) | |
if len(matches)==0: | |
print "Warning " | |
return 2.5 | |
avg = 0.0 | |
for m in matches: | |
avg+=m.gol1+m.gol2 | |
avg = avg/len(matches) | |
return avg | |
def getResultadoHistorico(self, anio, fecha, local, visitante): | |
f1 = or_(and_(PartidoModel.anio==anio,PartidoModel.nfecha<fecha),or_(PartidoModel.anio<anio)) | |
avg = self.session.query(func.avg(PartidoModel.gol1+PartidoModel.gol2)).filter(f1).filter(PartidoModel.equipo1==local,PartidoModel.equipo2==visitante).order_by(desc(PartidoModel.anio),desc(PartidoModel.nfecha)) | |
result = avg.first()[0] | |
if not result: | |
return 2.5 | |
return float(result) | |
def normalize(self,feature): | |
if not feature: | |
return 0 | |
if feature > 2.5: | |
return 1 | |
else: | |
return 0 | |
def getPosicion(self, equipo, anio, fecha): | |
posicion = self.session.query(PosicionModel).filter(PosicionModel.anio==anio,PosicionModel.nfecha==fecha,PosicionModel.equipo==equipo).first() | |
if posicion: | |
return posicion.posicion | |
else: | |
return 0 | |
def features(self,local,visitante, anio=None,fecha=None): | |
grande_local, grande_visitante = 0 , 0 | |
if local in GRANDES: | |
grande_local = 1 | |
if visitante in GRANDES: | |
grande_visitante = 1 | |
f = [ | |
grande_local, | |
grande_visitante, | |
anio, | |
fecha, | |
self.getResultadoHistorico(anio,fecha,local,visitante), | |
self.getPromedio(local,True,2,anio,fecha), | |
self.getPromedio(local,True,4,anio,fecha), | |
self.getPromedio(local,True,8,anio,fecha), | |
self.getPromedio(visitante,False,2,anio,fecha), | |
self.getPromedio(visitante,False,4,anio,fecha), | |
self.getPromedio(visitante,False,8,anio,fecha), | |
self.getGoles(local,True,2,anio,fecha), | |
self.getGoles(local,True,4,anio,fecha), | |
self.getGoles(local,True,8,anio,fecha), | |
self.getGoles(visitante,False,2,anio,fecha), | |
self.getGoles(visitante,False,4,anio,fecha), | |
self.getGoles(visitante,False,8,anio,fecha), | |
self.getPosicion(local,anio,fecha), | |
self.getPosicion(visitante,anio,fecha), | |
# diferencia entre las posiciones de local y visitante | |
abs ( self.getPosicion(local,anio,fecha) - self.getPosicion(visitante,anio,fecha)), | |
] | |
if local == 'Independiente': | |
print f | |
return f | |
def predictWithScikit(self, predictor, local,visitante, anio=None,fecha=None): | |
feat = self.features(local, visitante, anio, fecha) | |
output = predictor.predict_proba(feat) | |
#print output | |
return output[0] | |
def predictWithCriollo(self,local,visitante, anio=None,fecha=None): | |
''' | |
predice el resultado de un partido | |
''' | |
w = self.weights | |
historico = self.getResultadoHistorico(anio,fecha,local,visitante) | |
if not historico: | |
historico = 2.5 | |
motivacion_local = self.getPromedio(local,True,self.daysOfHistory,anio,fecha) | |
if not motivacion_local: | |
motivacion_local = 2.5 | |
motivacion_visitante = self.getPromedio(visitante,False,self.daysOfHistory,anio,fecha) | |
if not motivacion_visitante: | |
motivacion_visitante = 2.5 | |
goles_local = self.getGoles(local, True, self.daysOfHistory, anio, fecha) | |
goles_visitante = self.getGoles(visitante, False, self.daysOfHistory, anio, fecha) | |
goles_probables = goles_local+goles_visitante | |
result = historico*w[0] + motivacion_local*w[1] + motivacion_visitante*w[2] + goles_probables*w[3] | |
return result | |
def test(self,anio,fecha,weights=None, days=4): | |
if weights: | |
self.weights = weights | |
if days: | |
self.daysOfHistory = days | |
partidos = self.session.query(PartidoModel).filter(PartidoModel.anio==anio, PartidoModel.nfecha>=fecha) | |
''' | |
partidos.filter(or_(PartidoModel.equipo1=='Gimnasia (LP)',PartidoModel.equipo2=='Gimnasia (LP)', | |
PartidoModel.equipo2=='Sarmiento (J)',PartidoModel.equipo1=='Sarmiento (J)'), | |
PartidoModel.equipo2=='Union',PartidoModel.equipo1=='Union', | |
PartidoModel.equipo2=='Independiente',PartidoModel.equipo1=='Independiente') | |
''' | |
partidos = partidos.all() | |
#partidos = self.session.query(PartidoModel).filter(PartidoModel.equipo1=='River Plate',PartidoModel.anio==anio, PartidoModel.nfecha>=fecha).all() | |
rmse = 0 | |
oks = 0 | |
bads = 0 | |
histogram = defaultdict(lambda:0) | |
for p in partidos: | |
predicted = self.predict(p.equipo1, p.equipo2, p.anio, p.nfecha) | |
real = p.gol1+p.gol2 | |
diff = abs(predicted-real) | |
rmse += diff*diff | |
ok = (predicted > 2.5 and real > 2.5 ) or ((predicted < 2.5 and real < 2.5 ) ) | |
if ok: | |
oks+=1 | |
histogram[p.equipo1]+=1 | |
histogram[p.equipo2]+=1 | |
else: | |
bads+=1 | |
histogram[p.equipo1]-=1 | |
histogram[p.equipo2]-=1 | |
#print "%s vs. %s %f\t Real %f Diff %f\t OK %s" % ( p.equipo1, p.equipo2, predicted, real , diff , ok ) | |
rmse = sqrt(rmse/(oks+bads)) | |
print "Total RMSE=%f\t OKs=%i BAD=%i PRECISION=%f" % ( rmse, oks, bads, oks*1.0/(oks+bads)*1.0) | |
sorted_x = sorted(histogram.items(), key=operator.itemgetter(1),reverse=True) | |
#print sorted_x[:4] | |
#print "-"*10 | |
#print sorted_x[-4:] | |
def trainDataSet(self, anio=2013): | |
''' | |
Crea el dataset de training con partidos a partir de un anio. | |
''' | |
partidos = self.session.query(PartidoModel).filter(PartidoModel.anio>=anio, PartidoModel.nfecha>3).all() | |
x = [] | |
y = [] | |
for p in partidos: | |
features = self.features(p.equipo1, p.equipo2, p.anio, p.nfecha) | |
x.append(features) | |
out = 0 | |
if (p.gol1+p.gol2)>2.5: | |
out = 1 | |
y.append(out) | |
return x , y | |
def testAll(self): | |
base_weight = [.2,.25,.25,.3] | |
all_weights = [ | |
] | |
for i in range(6): | |
copy = list(base_weight) | |
random.shuffle(copy) | |
all_weights.append(copy) | |
#years = [2015,2014,2013] | |
years = [2015] | |
days = [2,4,8,6] | |
#days = [8] | |
for w in all_weights: | |
print "---" | |
print "weights " + str(w) | |
for y in years: | |
for d in days: | |
print "year= %i days=%d" % ( y,d) | |
self.test(y,10,w,d) | |
def trainScikitPredictor(self): | |
X, y = self.trainDataSet(anio=2015) | |
#writer = csv.writer('features.csv') | |
print "Train dataset build. Has %i samples" % len(X) | |
#clf = RandomForestClassifier(n_estimators=50, max_depth=None, min_samples_split=5, random_state=0) | |
clf = svm.SVC(kernel='rbf', C=1) | |
#clf = SGDClassifier() | |
X_train, X_test, y_train, y_test = cross_validation.train_test_split(X, y, test_size=0.1, random_state=0) | |
clf = clf.fit(X_train, y_train) | |
print clf.score(X_test, y_test) | |
return clf | |
def predictNextFecha(self,scikit=True): | |
predictor = None | |
if scikit: | |
predictor = self.trainScikitPredictor() | |
url = "http://www.promiedos.com.ar/primera" | |
response = urllib2.urlopen(url) | |
response = HtmlResponse(url=url, body=response.read()) | |
m1 = response.xpath('//span[@class="datoequipo"]/text()') | |
for i in range(0,len(m1),2): | |
equipo1 = m1[i].extract().strip() | |
equipo2 = m1[i+1].extract().strip() | |
if scikit: | |
probs = self.predictWithScikit(predictor, equipo1, equipo2, 2015, 25 ) | |
print "Partido: %s vs. %s. Probs=> Under 2.5=%f Over 2.5=%f" % ( equipo1, equipo2, probs[0],probs[1]) | |
else: | |
print "Partido: %s vs. %s. Goles=> %f" % ( equipo1, equipo2, self.predictWithCriollo(equipo1, equipo2, 2015, 25)) | |
if __name__ == '__main__': | |
logging.basicConfig() | |
pr = Predictor() | |
pr.trainScikitPredictor() | |
#pr.predictNextFecha(scikit=True) | |
#pr.testAll() | |
#print pr.getPromedio('Racing (C)') | |
#print pr.predict('Independiente', 'Racing (C)') | |
#print avg | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment