Skip to content

Instantly share code, notes, and snippets.

# BrambleXu/polynomial_logistic_regression.py (last active Jun 16, 2019)

 import numpy as np import matplotlib.pyplot as plt # read data data = np.loadtxt("non_linear_data.csv", delimiter=',', skiprows=1) train_x = data[:, 0:2] train_y = data[:, 2] # plot data points # plt.plot(train_x[train_y == 1, 0], train_x[train_y == 1, 1], 'o') # plt.plot(train_x[train_y == 0, 0], train_x[train_y == 0, 1], 'x') # plt.show() # initialize parameter theta = np.random.randn(4) # standardization mu = train_x.mean(axis=0) sigma = train_x.std(axis=0) def standardizer(x): return (x - mu) / sigma std_x = standardizer(train_x) # add x0 and x3 to get matrix def to_matrix(x): x0 = np.ones([x.shape, 1]) # (20, 1) x3 = x[:, 0, np.newaxis] ** 2 # (20, 1) return np.hstack([x0, x, x3]) mat_x = to_matrix(std_x) # (20, 4) # sigmoid function def f(x): """ theta: (4,) x: (n, 4) return sigmoid(x) -> (4, 1) """ return 1 / (1 + np.exp(-np.dot(x, theta))) # classify sample to 0 or 1 def classify(x): return (f(x) >= 0.5).astype(np.int) # update times epoch = 2000 # learning rate ETA = 1e-3 # accuracy log accuracies = [] # update parameter for _ in range(epoch): """ f(mat_x) - train_y: (20,) mat_x: (20, 4) theta: (4,) dot production: (20,) x (20, 4) -> (4,) """ theta = theta - ETA * np.dot(f(mat_x) - train_y, mat_x) result = classify(mat_x) == train_y # result is [Ture, False, ...] accuracy = sum(result) / len(result) accuracies.append(accuracy) ## plot line # x1 = np.linspace(-2, 2, 100) # x2 = - (theta + x1 * theta + theta * x1**2) / theta # plt.plot(std_x[train_y == 1, 0], std_x[train_y == 1, 1], 'o') # train data of class 1 # plt.plot(std_x[train_y == 0, 0], std_x[train_y == 0, 1], 'x') # train data of class 0 # plt.plot(x1, x2, linestyle='dashed') # plot the line we learned # plt.show() # plot accuracy line x = np.arange(len(accuracies)) plt.plot(x, accuracies) plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.
You can’t perform that action at this time.