Created
February 10, 2021 16:54
-
-
Save bquast/08b04acfa742a19210e4b7705549de5e to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# logsumexp
# Numerically stable log(sum(exp(x))). Shifting by max(x) keeps exp() from
# overflowing for large inputs and from underflowing to 0 for very negative
# ones.
#
# x: numeric vector.
# Returns a single number. When max(x) is not finite (every element -Inf,
# any element +Inf, or NA/NaN present), that value is returned as-is. The
# shifted form would otherwise compute -Inf - (-Inf) or Inf - Inf, which is
# NaN.
logsumexp <- function(x) {
  y <- max(x)
  if (!is.finite(y)) {
    # log(sum(exp(x))) is -Inf when all terms are -Inf and +Inf when any
    # term is +Inf; NA/NaN propagate unchanged.
    return(y)
  }
  y + log(sum(exp(x - y)))
}
# softmax
# Turns a numeric vector into weights exp(x_i) / sum(exp(x)). The
# normaliser is computed in log space with logsumexp(), which shifts by
# max(x), so large inputs do not overflow exp().
softmax <- function(x) {
  log_total <- logsumexp(x)
  log_probs <- x - log_total
  exp(log_probs)
}
# Appears to be a minimal dot-product self-attention sketch (no learned
# projections, no scaling):
#   - scores are column-wise dot products of x;
#   - each row of scores is softmax-normalised;
#   - each output column is the resulting weighted mix of input columns.
# All shapes below are derived from x, so changing its dimensions is safe.

# create inputs and outputs
# NOTE(review): no seed is set, so x differs on every run; call set.seed()
# beforehand if reproducible output is needed.
x <- matrix(sample(1:10, 18, replace = TRUE), nrow = 3)
# Target (twice the input); not used anywhere below.
y <- x * 2

# Weights: w[j, i] is the dot product of input columns j and i, i.e. the
# ncol(x) x ncol(x) Gram matrix t(x) %*% x.
w <- crossprod(x)

# softmax for each row, so every row of w sums to 1. apply() returns one
# column per row of w, hence the transpose back.
w <- t(apply(w, 1, softmax))

# Predictions: column i of yhat is sum over j of x[, j] * w[i, j], which is
# x %*% t(w). The result has the same shape as x.
yhat <- tcrossprod(x, w)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment