Created
February 7, 2017 16:48
-
-
Save abisz/45108a0177781f7790169797800ce176 to your computer and use it in GitHub Desktop.
This is a basic example of machine learning: a perceptron learning algorithm is trained on an artificially created data set.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Artificial target function that the perceptron is supposed to learn.
# In a real machine-learning application this function would, of course,
# be unknown.
#
# x: a numeric vector or one-row data frame; only its first two entries
#    (the coordinates) are used.
# Returns -1 when the two coordinates sum to more than -5, and 1 otherwise.
target <- function(x) {
  total <- x[1] + x[2]
  if (total > -5) -1 else 1
}
# Size of the data set.
N <- 100
# Build the data set in one vectorized step: two integer coordinates drawn
# uniformly (with replacement) from -10..10 in columns V1 and V2. This
# replaces a loop that grew the data frame one row at a time.
D <- data.frame(
  V1 = sample(-10:10, N, replace = TRUE),
  V2 = sample(-10:10, N, replace = TRUE)
)
# Label every point with the target function. seq_len() keeps this correct
# even when N is 0, where 1:N would yield c(1, 0).
D$y <- vapply(seq_len(N), function(i) target(D[i, ]), numeric(1))
# Plot the data set; color indicates the class (red: y = -1, green: y = 1).
plot(D[c(1, 2)], col = ifelse(D$y < 0, "red", "green"), pch = 19)
# Initialize the three perceptron weights (bias, V1, V2) uniformly in (0, 1).
w <- runif(3)

# Perceptron hypothesis: the function that should approximate target().
# The key to solving the problem is finding the right weights.
#
# x:       a numeric vector or one-row data frame whose first two entries are
#          the coordinates; any further entries (e.g. the label y) are ignored.
# weights: length-3 weight vector (bias, V1, V2). Defaults to the global `w`,
#          looked up at call time, so callers relying on the global still work.
# Returns 1 if the weighted sum is positive, -1 otherwise.
hypothesis <- function(x, weights = w) {
  # Prepend the constant bias input and keep only (1, x1, x2).
  input <- c(1, unlist(x))[1:3]
  if (sum(weights * input) > 0) 1 else -1
}
# Train the perceptron: repeatedly pick a misclassified point at random and
# nudge the weights toward classifying it correctly, drawing the current
# decision boundary after every update.

# Number of the current weight update (printed each iteration).
counter <- 1
# Safety cap on updates: the perceptron rule only terminates when the data
# are linearly separable, so without a cap a non-separable target would
# loop forever.
max_iter <- 10000
repeat {
  # Indices of all points misclassified by the current hypothesis. This is
  # computed once per iteration, and only the two coordinate columns are
  # passed to apply() (not the label column).
  wrong <- which(apply(as.matrix(D[1:2]), 1, hypothesis) != D$y)
  if (length(wrong) == 0) {
    break
  }
  if (counter > max_iter) {
    warning("perceptron did not converge within ", max_iter, " iterations",
            call. = FALSE)
    break
  }
  # Choose one misclassified point at random; sample.int() avoids the
  # sample(x, 1) pitfall when only a single index is left.
  e <- D[wrong[sample.int(length(wrong), 1)], ]
  # Perceptron update rule: w <- w + y * (1, x1, x2).
  w <- w + e$y * c(1, as.numeric(e[1:2]))
  # Plot the current decision boundary w1 + w2 * x + w3 * y = 0,
  # rewritten as the line y = k * x + d.
  k <- -(w[2] / w[3])
  d <- -(w[1] / w[3])
  curve(x * k + d, add = TRUE)
  # Pause so each boundary update stays visible.
  Sys.sleep(0.5)
  print(paste("iteration:", counter))
  counter <- counter + 1
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment