Skip to content

Instantly share code, notes, and snippets.

@hnagata
Last active August 29, 2015 14:21
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save hnagata/549b68a1b6e2a1060c5e to your computer and use it in GitHub Desktop.
Save hnagata/549b68a1b6e2a1060c5e to your computer and use it in GitHub Desktop.
## grid / ggplot ----
library(ggplot2)
library(grid)
# Start a fresh grid page and push a viewport that carries a row x col layout,
# so that subsequent drawing can be directed at individual layout cells.
#   row: number of rows in the layout
#   col: number of columns in the layout
# The pushed viewport stays active until it is popped again.
make.grid <- function(row, col) {
  grid.newpage()
  pushViewport(viewport(layout=grid.layout(nrow=row, ncol=col)))
}
# Print object `o` into the layout cell at row `i`, column `j` of the
# currently pushed layout viewport. `o` must have a print method that
# accepts a `vp` argument.
print.at <- function(o, i, j) {
  cell <- viewport(layout.pos.row=i, layout.pos.col=j)
  print(o, vp=cell)
}
# Pop one viewport off the grid viewport stack (popViewport's default
# count is 1).
end.grid <- function() {
  popViewport(1)
}
## Load data ----
# stringsAsFactors=TRUE is required: the rest of the script calls levels() on
# `user` / `requestNo` and uses their integer codes as row indices. Since
# R 4.0 read.csv() no longer converts strings to factors by default.
dat <- read.csv("user.csv", fileEncoding="utf-8", stringsAsFactors=TRUE)
# Score columns that are modelled
elemname <- c("chartInterval", "chartStability", "chartExpressiveness", "chartVibratoLongtone", "chartRhythm")
p <- length(elemname)
lv.user <- levels(dat$user)
lv.reqno <- levels(dat$requestNo)
n.user <- length(lv.user)
n.reqno <- length(lv.reqno)
# For every (song, user) pair keep only the row with the highest total score.
# NOTE(review): the pair key is a plain concatenation without a separator, so
# two different pairs could in principle produce the same key - confirm the
# id formats rule this out. The key is left unchanged because it also
# determines the row order of the filtered data.
filt.idx <- tapply(seq_len(nrow(dat)), factor(paste0(dat$requestNo, dat$user)), function(idx) {
  idx[which.max(dat[idx, "totalPoint"])]
})
dat <- dat[filt.idx, ]
# Convert pitch names to numbers
#
# Apply a table of substitutions to a character vector, one row at a time.
#   reptb: data frame (or matrix) whose first column holds the patterns and
#          whose second column holds the replacements. Patterns are passed to
#          gsub() as regular expressions. Rows are applied in order, so a
#          pattern that contains another one must come first.
#   x:     character vector to transform.
# Returns `x` after all substitutions; `x` is returned unchanged when the
# table has no rows.
# Note: this masks base::replace for the rest of the script.
replace <- function(reptb, x) {
  for (i in seq_len(nrow(reptb))) {
    x <- gsub(as.character(reptb[i, 1]), as.character(reptb[i, 2]), x)
  }
  x
}
# Rewrite the flat sign as ASCII "b" in both pitch columns so the names can be
# matched against the lookup table below.
dat[c("highPitch", "lowPitch")] <- lapply(
  dat[c("highPitch", "lowPitch")],
  function(pitch) gsub("♭", "b", pitch)
)
# Lookup table: pitch name (column 1) -> numeric pitch (column 2).
# Row order matters because the substitutions are applied in sequence:
# longer names ("hihiAb") must precede the names they contain ("hiAb", "Ab").
reptb.pitch <- data.frame(
  c("lowAb", "lowA", "lowBb", "lowB", "lowC", "lowDb", "lowD", "lowEb", "lowE", "~lowF", "lowGb", "lowG",
    "m1Ab", "m1A", "m1Bb", "m1B", "m1C", "m1Db", "m1D", "m1Eb", "m1E", "m1F", "m1Gb", "m1G",
    "m2Ab", "m2A", "m2Bb", "m2B", "m2C", "m2Db", "m2D", "m2Eb", "m2E", "m2F", "m2Gb", "m2G",
    "hihiAb", "hihiA", "hihiBb", "hihiB~", "hihiC", "hihiDb", "hihiD", "hihiEb", "hihiE", "hihiF", "hihiGb", "hihiG",
    "hiAb", "hiA", "hiBb", "hiB", "hiC", "hiDb", "hiD", "hiEb", "hiE", "hiF", "hiGb", "hiG",
    "Ab", "A", "Bb", "B", "C", "Db", "D", "Eb", "E", "F", "Gb", "G"
  ),
  c(32 : 43, 44 : 55, 56 : 67, 80 : 91, 68 : 79, 68 : 79)
)
# Substitute the names and turn the resulting digit strings into numbers.
dat[c("highPitch", "lowPitch")] <- lapply(
  dat[c("highPitch", "lowPitch")],
  function(pitch) as.numeric(replace(reptb.pitch, pitch))
)
# Build the song table: one row per song level, taking each attribute from the
# first row of `dat` that belongs to the song.
songs <- local({
  first.per.song <- function(values) {
    tapply(values, dat$requestNo, function(v) v[1])
  }
  data.frame(
    reqno=lv.reqno,
    artist=factor(first.per.song(as.character(dat$artist))),
    contents=factor(first.per.song(as.character(dat$contents))),
    highPitch=first.per.song(dat$highPitch),
    lowPitch=first.per.song(dat$lowPitch)
  )
})
# Pitch range of the song, counting both end points
songs$diffPitch <- with(songs, highPitch - lowPitch + 1)
## Build training / test data ----
# Hold out 5% of all rows for testing
set.seed(0)
index.test <- sample(1 : nrow(dat), nrow(dat) * 0.05)
dat.test <- dat[index.test, ]
# Select the training rows by position instead of `dat[-index.test, ]`:
# negative indexing with an empty index would drop every row, leaving no
# training data when fewer than 20 rows are available.
dat.train <- dat[setdiff(seq_len(nrow(dat)), index.test), ]
# Drop test rows whose song never occurs in the training data (to be reviewed).
# Indexing the table with the factor uses its integer codes, which line up
# with the table entries because both follow the same level order.
dat.test <- dat.test[as.vector(table(dat.train$requestNo)[dat.test$requestNo]) > 0, ]
# Temporary variables
y <- as.matrix(dat.train[, elemname])
user.test <- as.numeric(dat.test$user)
reqno.test <- as.numeric(dat.test$requestNo)
y.test <- as.matrix(dat.test[, elemname])
# Check the number of rows
c(train=nrow(dat.train), test=nrow(dat.test))
## Baseline: use each user's mean as the prediction ----
# Mean squared error between observed and predicted values.
#   true.y: observed values (vector or matrix)
#   pred.y: predicted values of the same shape
# Returns the sum of squared differences divided by the number of elements
# of `true.y`.
mse <- function(true.y, pred.y) {
  sq.err <- (pred.y - true.y)^2
  sum(sq.err) / length(true.y)
}
# Per-user mean of every score column on the training data, looked up for the
# user of each test row.
pred.y.by.mean <- apply(
  y, 2,
  function(score) tapply(score, dat.train$user, mean)
)[user.test, ]
mse(y.test, pred.y.by.mean)
## diaglm 実装 ----
library(parallel)
# Fit y[n, k] ~ a[song(n), k] * x[user(n), k] separately for every score
# column k listed in the global `elemname`, by alternating least squares:
# song coefficients `a` and user coefficients `x` are re-estimated in turn
# until both stop changing.
#
# Args:
#   dat:       data frame with factor columns `requestNo` (song) and `user`
#              and one numeric column per entry of `elemname`.
#   threshold: only songs occurring at least this many times in `dat` are used
#              for fitting; coefficients of all other songs stay at their
#              initial value 1.
#   weighted:  if TRUE, each observation gets a weight derived from how often
#              its score value occurs in its column; otherwise all weights
#              are 1.
#   verbose:   if TRUE, print the convergence measure after every iteration.
#   cl:        optional cluster from the parallel package; when given, the
#              per-song and per-user updates run through parSapply().
#
# Returns a list with
#   a:     (number of song levels) x p matrix of song coefficients,
#   x:     (number of user levels) x p matrix of user coefficients,
#   resid: fitted minus observed values for the rows used in fitting,
#   r2:    per-column (var(y) - mean squared residual) / var(y).
diaglm <- function(dat, threshold=4, weighted=FALSE, verbose=TRUE, cl=NULL) {
  # Serial sapply by default, parSapply on the supplied cluster otherwise.
  map.rows <- if (is.null(cl)) sapply else function(...) parallel::parSapply(cl, ...)

  # Sizes come from the data itself so the function does not depend on
  # matching global counters.
  p <- length(elemname)
  n.reqno <- nlevels(dat$requestNo)
  n.user <- nlevels(dat$user)

  # Use only songs that occur at least `threshold` times
  dat <- dat[as.vector(table(dat$requestNo)[dat$requestNo] >= threshold), ]
  if (nrow(dat) == 0) {
    stop("diaglm: no rows left after applying threshold=", threshold, call.=FALSE)
  }
  reqno <- as.numeric(dat$requestNo)
  user <- as.numeric(dat$user)
  y <- as.matrix(dat[, elemname, drop=FALSE])

  # Initial values: all song coefficients 1, user coefficients = user means
  a <- matrix(1, n.reqno, p)
  x <- matrix(apply(y, 2, function(col) tapply(col, dat$user, mean)), nrow=n.user, ncol=p)
  if (weighted) {
    w <- matrix(apply(y, 2, function(col) table(col)[as.character(col)]), nrow=nrow(dat), ncol=p)
  } else {
    w <- matrix(1, nrow(dat), p)
  }

  # sapply() yields one column per index (or a plain vector when p == 1 or
  # only one index is given); reshape to one row per index in every case.
  as.rows <- function(est, n) matrix(unlist(est), nrow=n, ncol=p, byrow=TRUE)

  # Least-squares coefficient of song j for every column, user side fixed.
  # NOTE(review): sqrt(w) multiplies only one factor of each sum, so the
  # effective weight is sqrt(w) rather than w - kept as in the original;
  # confirm this is intended.
  fit.song <- function(j, y, xx, w, reqno, user) {
    rows <- reqno == j
    sub.y <- y[rows, , drop=FALSE]
    sub.x <- xx[user[rows], , drop=FALSE]
    sub.wx <- sqrt(w[rows, , drop=FALSE]) * sub.x
    colSums(sub.wx * sub.y) / colSums(sub.wx * sub.x)
  }
  # Least-squares coefficient of user i for every column, song side fixed.
  # Users without any row yield 0 / 0 = NaN.
  fit.user <- function(i, y, a, w, reqno, user) {
    rows <- user == i
    sub.y <- y[rows, , drop=FALSE]
    sub.a <- a[reqno[rows], , drop=FALSE]
    sub.wa <- sqrt(w[rows, , drop=FALSE]) * sub.a
    colSums(sub.wa * sub.y) / colSums(sub.wa * sub.a)
  }

  # Songs that actually occur in the filtered data (constant across iterations)
  j <- unname(which(table(dat$requestNo) > 0))

  # Alternating optimisation
  iter <- 0
  err <- Inf
  while (err > 1e-07) {
    iter <- iter + 1
    a0 <- a
    x0 <- x
    # The argument is called `xx`, not `x`, as in the original call; a name
    # `x` would clash with a formal of the cluster apply functions.
    a[j, ] <- as.rows(
      map.rows(j, fit.song, y=y, xx=x, w=w, reqno=reqno, user=user),
      length(j)
    )
    x <- as.rows(
      map.rows(seq_len(n.user), fit.user, y=y, a=a, w=w, reqno=reqno, user=user),
      n.user
    )
    # Mean squared change of both coefficient sets; changes in x are scaled
    # by 0.01 before squaring.
    err <- sum((a - a0)^2) / (n.reqno * p) + sum(((x - x0) * 0.01)^2) / (n.user * p)
    if (verbose) {
      cat(paste0("#", iter, ": ", round(err, digits=8), "\n"))
    }
  }

  resid <- a[reqno, , drop=FALSE] * x[user, , drop=FALSE] - y
  r2 <- vapply(seq_len(p), function(k) {
    (var(y[, k]) - sum(resid[, k]^2) / nrow(dat)) / var(y[, k])
  }, numeric(1))
  colnames(a) <- elemname
  colnames(x) <- elemname
  list(a=a, x=x, resid=resid, r2=r2)
}
## Estimation ----
cl <- makeCluster(4)
# diaglm
diaglm.std <- diaglm(dat.train, threshold=4, weighted=FALSE, cl=cl)
pred.y.by.diaglm.std <- diaglm.std$a[reqno.test, ] * diaglm.std$x[user.test, ]
c(mse=mse(y.test, pred.y.by.diaglm.std), r2=diaglm.std$r2)
# weighted diaglm
diaglm.w <- diaglm(dat.train, threshold=0, weighted=TRUE, cl=cl)
pred.y.by.diaglm.w <- diaglm.w$a[reqno.test, ] * diaglm.w$x[user.test, ]
c(mse=mse(y.test, pred.y.by.diaglm.w), r2=diaglm.w$r2)
# Effect of the threshold: one row of (mse, r2 per score column) for each
# threshold 1..6. The fits reuse the cluster that is already running, like
# the two calls above, instead of silently falling back to serial execution.
summary.diaglm.t <- do.call(rbind, lapply(1 : 6, function(threshold) {
  diaglm.t <- diaglm(dat.train, threshold=threshold, cl=cl)
  pred <- diaglm.t$a[reqno.test, ] * diaglm.t$x[user.test, ]
  c(mse=mse(y.test, pred), r2=diaglm.t$r2)
}))
summary.diaglm.t
stopCluster(cl)
## Compare the mean baseline and diaglm on the Interval score only
int.df <- data.frame(
  true = y.test[, "chartInterval"],
  pred.mean = pred.y.by.mean[, "chartInterval"],
  pred.diaglm = pred.y.by.diaglm.std[, "chartInterval"]
)
# Predicted vs. true scatter plots with the identity line, both axes limited
# to 50..100: g1 for the diaglm predictions, g2 for the baseline.
g1 <- ggplot(int.df, aes(x=true, y=pred.diaglm)) +
  geom_point() +
  geom_abline(slope=1) +
  xlim(50, 100) +
  ylim(50, 100) +
  labs(x="True int", y="Predicted int (proposed)")
g2 <- ggplot(int.df, aes(x=true, y=pred.mean)) +
  geom_point() +
  geom_abline(slope=1) +
  xlim(50, 100) +
  ylim(50, 100) +
  labs(x="True int", y="Predicted int (baseline)")
# Write both plots side by side into one SVG file
svg("interval.svg", width=8, height=4)
make.grid(1, 2)
print.at(g1, 1, 1)
print.at(g2, 1, 2)
end.grid()
dev.off()
## Look at the samples with the largest residuals ----
int.df$resid.diaglm <- int.df$pred.diaglm - int.df$true
int.df$resid.mean <- int.df$pred.mean - int.df$true
# Sort by absolute diaglm residual, largest first
int.df <- int.df[order(abs(int.df$resid.diaglm), decreasing=TRUE), ]
# head() shows at most 20 rows; `[1:20, ]` would pad the output with NA rows
# when the test set has fewer than 20 samples.
head(int.df[, c("true", "resid.diaglm", "resid.mean")], 20)
## Which songs are hard to sing in tune? ----
# Refit on the full data set (training and test rows together)
cl <- makeCluster(4)
diaglm.full <- diaglm(dat, threshold=4, weighted=FALSE, cl=cl)
stopCluster(cl)
# Song coefficient of the first score column next to the song attributes.
# Songs with fewer occurrences than the threshold keep the initial value 1.
df <- data.frame(
  a=diaglm.full$a[, 1],
  songs[, c("diffPitch", "contents", "artist")]
)
# Ten smallest and ten largest coefficients; head() avoids the NA-padded rows
# that `[1 : 10, ]` produces when there are fewer than ten songs.
head(df[order(df$a), ], 10)
head(df[order(df$a, decreasing=TRUE), ], 10)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment