Last active
February 8, 2018 09:42
-
-
Save cjue25/cacf551a05d07d86855c292f87a31270 to your computer and use it in GitHub Desktop.
Using Python for Research_Case Study 7
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import statsmodels.api as sm

# NOTE(review): x and y are not defined in this snippet; they are assumed to
# be set up earlier in the session -- confirm before running standalone.

# Fit y on x as given, without adding a constant column.
mod = sm.OLS(y, x)  # ordinary least squares
est = mod.fit()
print(est.summary())
# The slope obtained here is larger because the model has no intercept:
# the fitted line is drawn starting from 0.
# summary() prints the detailed fit information.

# Refit after prepending a constant column so an intercept is estimated too.
X = sm.add_constant(x)
mod = sm.OLS(y, X)
est = mod.fit()
print(est.summary())
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Axes3D is never referenced by name below; the import is kept as in the
# original (presumably to make the '3d' projection available on older
# matplotlib releases -- confirm before removing).
from mpl_toolkits.mplot3d import Axes3D

# NOTE(review): plt, X and y are not defined in this snippet; X is indexed as
# a 2-D array with at least two columns. Confirm they are set up earlier.
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
# Plot the two predictor columns against y, colouring each point by y.
ax.scatter(X[:, 0], X[:, 1], y, c=y)
ax.set_xlabel("$x_1$")
ax.set_ylabel("$x_2$")
ax.set_zlabel("$y$")
plt.show()  # 3D plot
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split

# NOTE(review): X, y and np are not defined in this snippet; they are assumed
# to come from earlier in the session -- confirm before running standalone.
lm = LinearRegression(fit_intercept=True)
lm.fit(X, y)

# The bare expressions below only display a value in an interactive session
# (notebook / REPL); when run as a script they are evaluated and discarded.
lm.intercept_  # intercept
lm.coef_[0]  # beta_1
lm.coef_[1]  # beta_2

# test
X_0 = np.array([2, 4])
# predict() is given a single row holding both feature values.
lm.predict(X_0.reshape(1, -1))
# score() reports R^2 (coefficient of determination); here it is computed on
# the same data the model was fitted on.
print(lm.score(X, y))

# Hold out half of the data, refit on the training half only, and report the
# score on the unseen half.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.5, random_state=1
)
lm = LinearRegression(fit_intercept=True)
lm.fit(X_train, y_train)
print(lm.score(X_test, y_test))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import scipy.stats as ss | |
import matplotlib.pyplot as plt | |
from sklearn.linear_model import LinearRegression | |
from sklearn.model_selection import train_test_split | |
# Module-level defaults. The visible code never reads them: gen_data() is
# called with literal arguments and n is reassigned before its next use.
h = 1
sd = 1
n = 50
def gen_data(n, h, sd1, sd2):
    """Draw two synthetic 2-D point clouds from normal distributions.

    The first cloud has x ~ N(-h, sd1) and y ~ N(0, sd1); the second has
    x ~ N(h, sd2) and y ~ N(0, sd2).  Samples are drawn with
    ``scipy.stats.norm.rvs`` in the order x1, y1, x2, y2.

    Args:
        n: Number of points to draw for each cloud.
        h: Horizontal offset; the clouds are centred at x = -h and x = +h.
        sd1: Standard deviation of both coordinates of the first cloud.
        sd2: Standard deviation of both coordinates of the second cloud.

    Returns:
        Tuple ``(x1, y1, x2, y2)`` of four arrays with ``n`` samples each.
    """
    x1 = ss.norm.rvs(loc=-h, scale=sd1, size=n)
    y1 = ss.norm.rvs(loc=0, scale=sd1, size=n)
    x2 = ss.norm.rvs(loc=h, scale=sd2, size=n)
    y2 = ss.norm.rvs(loc=0, scale=sd2, size=n)
    return (x1, y1, x2, y2)
# Draw 1000 points per cloud: offset 1.5, standard deviations 1 and 1.5.
x1, y1, x2, y2 = gen_data(1000, 1.5, 1, 1.5)
def plot_data(x1, y1, x2, y2):
    """Scatter-plot two point clouds on a new figure.

    Opens a fresh figure, draws (x1, y1) and (x2, y2) as two separate series
    of small round markers, and labels the axes.  The figure is neither shown
    nor saved here; that is left to the caller.

    Args:
        x1: Horizontal coordinates of the first cloud.
        y1: Vertical coordinates of the first cloud.
        x2: Horizontal coordinates of the second cloud.
        y2: Vertical coordinates of the second cloud.
    """
    plt.figure()
    plt.plot(x1, y1, 'o', ms=2)
    plt.plot(x2, y2, 'o', ms=2)
    plt.xlabel("$X_1$")
    plt.ylabel("$X_2$")
# Draw both clouds and write the current figure to disk; the file name has no
# extension, so matplotlib chooses its default output format.
plot_data(x1, y1, x2, y2)
plt.savefig("gen_data")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from sklearn.linear_model import LogisticRegression

clf = LogisticRegression()

# Stack both clouds into one (rows, 2) feature matrix: first-cloud rows first,
# then second-cloud rows, each row being an (x, y) pair.
X = np.vstack((np.column_stack((x1, y1)), np.column_stack((x2, y2))))

# Take the label counts from the data itself so they always match the rows of
# X, instead of a literal that has to be kept in sync with the gen_data() call.
n = len(x1)
y = np.hstack((np.repeat(1, n), np.repeat(2, len(x2))))  # labels 1 and 2

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.5, random_state=1
)
clf.fit(X_train, y_train)
print(clf.score(X_test, y_test))

# Query a single point at (-2, 0), passed as one row with two features.
query = np.array([-2, 0]).reshape(1, -1)
print(clf.predict_proba(query))  # probability of each class
print(clf.predict(query))
def plot_probs(ax, clf, class_no):
    """Draw a filled contour map of one class's predicted probability.

    Evaluates ``clf.predict_proba`` on a regular grid covering [-5, 5) in
    both coordinates with step 0.1, and plots the selected probability
    column on ``ax`` together with a colorbar.

    Args:
        ax: Matplotlib axes to draw on.
        clf: Fitted classifier exposing ``predict_proba`` for two features.
        class_no: Column index into the ``predict_proba`` output (0-based).
    """
    axis_values = np.arange(-5, 5, 0.1)
    xx1, xx2 = np.meshgrid(axis_values, axis_values)
    # One row per grid point, columns are the two feature values.
    grid_points = np.stack((xx1.ravel(), xx2.ravel()), axis=1)
    probs = clf.predict_proba(grid_points)
    # Back to the grid's 2-D shape so contourf can draw it.
    Z = probs[:, class_no].reshape(xx1.shape)
    CS = ax.contourf(xx1, xx2, Z)
    # Attach the colorbar and labels to the axes that was passed in rather
    # than to whichever axes happens to be current.
    plt.colorbar(CS, ax=ax)
    ax.set_xlabel("$X_1$")
    ax.set_ylabel("$X_2$")
# Two side-by-side panels, one per column of predict_proba.
plt.figure(figsize=(10, 4))

ax = plt.subplot(121)
plot_probs(ax, clf, 0)  # column 0 -> label 1
plt.title("Pred. prob for class 1")

ax = plt.subplot(122)
plot_probs(ax, clf, 1)  # column 1 -> label 2
plt.title("Pred. prob for class 2")

plt.savefig("classification.png")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment