Skip to content

Instantly share code, notes, and snippets.

@emakryo
Created March 16, 2017 02:30
Show Gist options
  • Save emakryo/76bf671254326581b7392728c90e92ee to your computer and use it in GitHub Desktop.
Save emakryo/76bf671254326581b7392728c90e92ee to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.model_selection import GridSearchCV\n",
"from sklearn.svm import SVC"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Synthetic two-class data: n_train/2 points centred at +1 and n_train/2 at -1 in every dimension.\n",
"SEED = 0\n",
"xdim = 800\n",
"n_train = 500\n",
"n_half = n_train // 2  # integer division: array sizes must be ints (n_train/2 is a float in Python 3)\n",
"rng = np.random.RandomState(SEED)  # fixed seed so Restart & Run All reproduces the same data\n",
"x = np.concatenate([rng.randn(n_half, xdim) + 1, rng.randn(n_half, xdim) - 1], axis=0)\n",
"y = np.array([1] * n_half + [0] * n_half)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Min-max scale every feature to [0, 1]; keep offset (mu) and scale (sigma) to transform new data identically.\n",
"# Note: mu is the per-feature minimum here, not the mean.\n",
"x_max = x.max(axis=0)\n",
"x_min = x.min(axis=0)\n",
"mu = x_min\n",
"# Guard constant features (max == min), which would otherwise divide by zero and yield NaN.\n",
"sigma = np.where(x_max > x_min, x_max - x_min, 1.0)\n",
"x_norm = (x - mu) / sigma"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"GridSearchCV(cv=10, error_score='raise',\n",
" estimator=SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,\n",
" decision_function_shape=None, degree=3, gamma='auto', kernel='rbf',\n",
" max_iter=-1, probability=False, random_state=None, shrinking=True,\n",
" tol=0.001, verbose=False),\n",
" fit_params={}, iid=True, n_jobs=-1,\n",
" param_grid={'C': [0.001, 0.002, 0.005, 0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 1, 2, 5, 10]},\n",
" pre_dispatch='2*n_jobs', refit=True, return_train_score=True,\n",
" scoring=None, verbose=0)"
]
},
"execution_count": 40,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# C grid on a 1-2-5 log scale.\n",
"params = {'C': [0.001, 0.002, 0.005, 0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 1, 2, 5, 10]}\n",
"# kernel and gamma are pinned explicitly: the manual decision-function check below assumes an\n",
"# RBF kernel with gamma = 1/n_features ('auto'); scikit-learn >= 0.22 defaults to gamma='scale'.\n",
"model = GridSearchCV(SVC(kernel='rbf', gamma='auto'), param_grid=params, n_jobs=-1, cv=10)\n",
"model.fit(x_norm, y)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Pull the fitted parameters of the best SVC out of the grid search.\n",
"best_svc = model.best_estimator_\n",
"b = best_svc.intercept_  # one entry for a binary problem\n",
"x_sv = best_svc.support_vectors_\n",
"y_sv = y[best_svc.support_]  # 0/1 labels of the support vectors (not needed for the decision function)\n",
"# dual_coef_ already holds the signed products y_i * alpha_i; do not multiply by labels again.\n",
"alpha = best_svc.dual_coef_.reshape(-1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Every support-vector array must have one entry per support vector.\n",
"assert x_sv.shape[0] == y_sv.shape[0] == alpha.shape[0]\n",
"x_sv.shape, y_sv.shape, alpha.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# With gamma='auto' the RBF kernel is exp(-||u - v||^2 / n_features), so K = 1/gamma = n_features.\n",
"assert best_svc.gamma in ('auto', 'auto_deprecated'), 'manual kernel check assumes gamma=1/n_features'\n",
"K = x_sv.shape[1]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Reference decision values from scikit-learn (positive => predicted class 1).\n",
"f = model.decision_function(x_norm)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Show a summary instead of dumping all n_train values.\n",
"f.shape, f[:5], f[-5:]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Recompute the decision function by hand: f(x) = b + sum_i alpha_i * exp(-||x - sv_i||^2 / K),\n",
"# where alpha (dual_coef_) already carries the label signs, so no extra factor of y is applied.\n",
"# Squared distances via ||u||^2 + ||v||^2 - 2 u.v avoid an (n_samples, n_sv, n_features) temporary.\n",
"sq_dist = (np.sum(x_norm ** 2, axis=1)[:, np.newaxis]\n",
"           + np.sum(x_sv ** 2, axis=1)[np.newaxis, :]\n",
"           - 2 * np.dot(x_norm, x_sv.T))\n",
"sq_dist = np.maximum(sq_dist, 0)  # clip tiny negative values caused by rounding\n",
"f_ = b + np.dot(np.exp(-sq_dist / K), alpha)\n",
"np.testing.assert_allclose(f_, f, rtol=1e-6, atol=1e-9)\n",
"np.abs(f_ - f).max()"
]
}
],
"metadata": {
"anaconda-cloud": {},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
}
},
"nbformat": 4,
"nbformat_minor": 1
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment