Last active
August 29, 2015 14:21
-
-
Save alfredplpl/1012bdc6e650e1eb083a to your computer and use it in GitHub Desktop.
ミニバッチSGDです。Kaggle見ても、作れとしか書かれてなかったので作りました。
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*-
# This code is distributed under the 3-Clause BSD license (New BSD license).
# In short, commercial use is permitted as long as the author is credited. No warranty is provided.
# Reference URL: http://osdn.jp/projects/opensource/wiki/licenses%2Fnew_BSD_license
from sklearn import linear_model | |
import Image | |
import numpy as np | |
from sklearn.cross_validation import ShuffleSplit | |
class MinibatchSGDRegressor(linear_model.SGDRegressor):
    """SGD regressor that is trained on randomly drawn mini-batches.

    Thin wrapper around ``linear_model.SGDRegressor``: ``fit`` draws random
    subsets of the training rows with ``ShuffleSplit`` and feeds each subset
    to ``partial_fit``.

    Parameters
    ----------
    batch_size : int
        Controls the mini-batch size. Each batch is drawn with
        ``test_size=1.0 / batch_size``, i.e. as a *fraction* of the data, so
        a batch holds about ``n_samples / batch_size`` rows rather than
        ``batch_size`` rows.
    n_iter, alpha, penalty, verbose
        Forwarded unchanged to ``linear_model.SGDRegressor.__init__``.
    """

    def __init__(self, batch_size=32, n_iter=5, alpha=0.0001, penalty="l2",
                 verbose=0):
        linear_model.SGDRegressor.__init__(
            self,
            n_iter=n_iter,
            alpha=alpha,
            penalty=penalty,
            verbose=verbose,
        )
        self.batch_size = batch_size

    def fit(self, X, y, n_iter=20, random_state=0):
        """Train the model on ``n_iter`` random mini-batches of ``(X, y)``.

        Parameters
        ----------
        X : array-like indexable by an integer index array
            Training samples; ``X.shape[0]`` is taken as the sample count.
        y : array-like indexable by an integer index array
            Target values aligned with the rows of ``X``.
        n_iter : int
            Number of mini-batches to draw. This is the argument of ``fit``
            and is independent of the ``n_iter`` given to the constructor.
        random_state : int
            Seed handed to ``ShuffleSplit`` so the batches are reproducible.

        Returns
        -------
        self : MinibatchSGDRegressor
            The fitted estimator, so calls can be chained
            (``est.fit(X, y).predict(X)``).
        """
        # NOTE(review): partial_fit is incremental, so calling fit() twice
        # presumably continues from the previous weights instead of
        # restarting -- confirm this is the intended behaviour.
        splitter = ShuffleSplit(
            X.shape[0],
            n_iter=n_iter,
            random_state=random_state,
            test_size=1.0 / self.batch_size,
        )
        # Only the "test" side of every split is used: it is the mini-batch.
        for _, batch in splitter:
            self.partial_fit(X=X[batch], y=y[batch])
        return self
# The demo below runs only when this file is executed as a script.
# It is skipped on import, so the module can be imported without side effects.
if __name__ == "__main__":
    from sklearn.datasets import make_regression
    from sklearn.cross_validation import cross_val_score

    # Synthetic single-target regression problem used to compare both models.
    X, y = make_regression(n_samples=100000, n_features=20, n_targets=1,
                           noise=10)
    clf = MinibatchSGDRegressor(batch_size=4)
    clf_baseline = linear_model.LinearRegression()

    # Parenthesised single-argument print behaves the same as the Python 2
    # print statement and is also valid Python 3.
    print(cross_val_score(clf, X, y, cv=2))
    print(cross_val_score(clf_baseline, X, y, cv=2))
# Copyright (c) 2015, alfredplpl | |
# All rights reserved. | |
__author__ = 'alfredplpl' | |
# References: https://github.com/lisa-lab/pylearn2/blob/master/pylearn2/training_algorithms/sgd.py | |
# https://www.kaggle.com/c/criteo-display-ad-challenge/forums/t/9561/how-to-apply-python-linear-model-sgdregressor-to-do-logistic-regression | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Kaggle見ても、作れとしか書かれてなかったので作りました。