Last active
August 29, 2015 14:21
-
-
Save hnagata/549b68a1b6e2a1060c5e to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| ## grid / ggplot ---- | |
| library(ggplot2) | |
| library(grid) | |
# Start a new grid page and push a viewport whose layout is a
# `row` x `col` grid, so that plots can then be placed cell by cell.
#
# Args:
#   row: number of layout rows.
#   col: number of layout columns.
make.grid <- function(row, col) {
  grid.newpage()
  pushViewport(viewport(layout=grid.layout(nrow=row, ncol=col)))
}
# Print object `o` inside the layout cell at row `i`, column `j` of the
# currently pushed grid layout. The cell viewport is handed to print() via
# its `vp` argument.
print.at <- function(o, i, j) {
  cell <- viewport(layout.pos.row=i, layout.pos.col=j)
  print(o, vp=cell)
}
# Pop the most recently pushed viewport (one level), closing the layout
# that was opened for cell-by-cell plotting.
end.grid <- function() {
  popViewport(n=1)
}
## Load data ----
# stringsAsFactors=TRUE is required: the rest of the script relies on `user`
# and `requestNo` being factors (levels(), as.numeric() codes), and since
# R 4.0 read.csv() no longer converts strings to factors by default.
dat <- read.csv("user.csv", fileEncoding="utf-8", stringsAsFactors=TRUE)
# Names of the score columns that are modelled.
elemname <- c("chartInterval", "chartStability", "chartExpressiveness", "chartVibratoLongtone", "chartRhythm")
p <- length(elemname)
lv.user <- levels(dat$user)
lv.reqno <- levels(dat$requestNo)
n.user <- length(lv.user)
n.reqno <- length(lv.reqno)
# For each (song, user) pair keep only the row with the highest total score
# (the first such row on ties, as returned by which.max()).
# The grouping key is built with a separator so that two different pairs can
# never collapse into the same concatenated string.
filt.idx <- tapply(
  seq_len(nrow(dat)),
  factor(paste(dat$requestNo, dat$user, sep="\r")),
  function(idx) idx[which.max(dat[idx, "totalPoint"])]
)
dat <- dat[filt.idx, ]
# Convert pitch labels to numbers
#
# Apply a table of substitutions to a character vector.
#
# Args:
#   reptb: table whose first column holds gsub() patterns and whose second
#          column holds the corresponding replacements. Rows are applied in
#          order, each one to the result of the previous substitution.
#   x:     vector to transform.
# Returns `x` after all substitutions; `x` is returned unchanged when
# `reptb` has no rows.
# Note: this definition masks base::replace() for the rest of the script.
replace <- function(reptb, x) {
  # seq_len() instead of 1 : nrow() so that an empty table means "no loop"
  # rather than iterating over c(1, 0).
  for (i in seq_len(nrow(reptb))) {
    x <- gsub(reptb[i, 1], reptb[i, 2], x)
  }
  x
}
# Lookup table: pitch label (first column) -> numeric pitch (second column).
# replace() applies the rows in order, so the prefixed labels are listed
# before the bare note names at the end of the table.
reptb.pitch <- data.frame(
  c("lowAb", "lowA", "lowBb", "lowB", "lowC", "lowDb", "lowD", "lowEb", "lowE", "~lowF", "lowGb", "lowG",
    "m1Ab", "m1A", "m1Bb", "m1B", "m1C", "m1Db", "m1D", "m1Eb", "m1E", "m1F", "m1Gb", "m1G",
    "m2Ab", "m2A", "m2Bb", "m2B", "m2C", "m2Db", "m2D", "m2Eb", "m2E", "m2F", "m2Gb", "m2G",
    "hihiAb", "hihiA", "hihiBb", "hihiB~", "hihiC", "hihiDb", "hihiD", "hihiEb", "hihiE", "hihiF", "hihiGb", "hihiG",
    "hiAb", "hiA", "hiBb", "hiB", "hiC", "hiDb", "hiD", "hiEb", "hiE", "hiF", "hiGb", "hiG",
    "Ab", "A", "Bb", "B", "C", "Db", "D", "Eb", "E", "F", "Gb", "G"
  ),
  c(32 : 43, 44 : 55, 56 : 67, 80 : 91, 68 : 79, 68 : 79)
)
# For both pitch columns: normalise the flat sign to "b", map the label to
# its number through the table, and store the result as numeric.
dat[c("highPitch", "lowPitch")] <- lapply(
  dat[c("highPitch", "lowPitch")],
  function(pitch) as.numeric(replace(reptb.pitch, gsub("♭", "b", pitch)))
)
# Build the song table: one row per level of requestNo, taking artist,
# contents and the pitch columns from the first row of `dat` for each song.
songs <- local({
  # First value of `v` within each requestNo group.
  first.by.song <- function(v) tapply(v, dat$requestNo, function(g) g[1])
  data.frame(
    reqno=lv.reqno,
    artist=factor(first.by.song(as.character(dat$artist))),
    contents=factor(first.by.song(as.character(dat$contents))),
    highPitch=first.by.song(dat$highPitch),
    lowPitch=first.by.song(dat$lowPitch)
  )
})
# Pitch span of the song, counting both end points.
songs$diffPitch <- songs$highPitch - songs$lowPitch + 1
## Build training / test data ----
# Send 5% of all rows to the test set.
set.seed(0)
index.test <- sample(seq_len(nrow(dat)), floor(nrow(dat) * 0.05))
dat.test <- dat[index.test, ]
# setdiff() rather than dat[-index.test, ]: negative indexing with an empty
# index vector would select zero rows instead of all rows.
dat.train <- dat[setdiff(seq_len(nrow(dat)), index.test), ]
# Drop test rows whose song does not appear in the training data
# (needs review).
dat.test <- dat.test[dat.test$requestNo %in% dat.train$requestNo, ]
# Temporary variables
y <- as.matrix(dat.train[, elemname])
user.test <- as.numeric(dat.test$user)
reqno.test <- as.numeric(dat.test$requestNo)
y.test <- as.matrix(dat.test[, elemname])
# Check the number of rows
c(train=nrow(dat.train), test=nrow(dat.test))
## Baseline: use the per-user mean as the prediction ----
# Mean squared error: the sum of squared differences between `pred.y` and
# `true.y`, divided by the number of elements of `true.y`.
mse <- function(true.y, pred.y) {
  sq.err <- (pred.y - true.y)^2
  sum(sq.err) / length(true.y)
}
# Baseline prediction: for every score column take the mean over each user's
# training rows, then look up the row of the user behind each test sample.
pred.y.by.mean <- local({
  user.mean <- apply(y, 2, function(col) tapply(col, dat.train$user, mean))
  user.mean[user.test, ]
})
mse(y.test, pred.y.by.mean)
| ## diaglm 実装 ---- | |
| library(parallel) | |
# Fit, separately for every score column k, the multiplicative model
#   y[n, k] ~ a[song(n), k] * x[user(n), k]
# by alternating closed-form least-squares updates of `a` and `x`.
#
# Args:
#   dat:       data frame with factor columns `requestNo` and `user` and one
#              numeric column per entry of the global `elemname`.
#   threshold: only songs that appear at least this many times in `dat` are
#              used for fitting.
#   weighted:  if TRUE, each observation gets as weight the number of rows
#              that share its score value in the same column; otherwise all
#              weights are 1.
#   verbose:   if TRUE, print the convergence measure after every iteration.
#   cl:        optional `parallel` cluster; when given, the per-song and
#              per-user updates are run through parSapply().
#
# Returns a list with
#   a:     n.reqno x p matrix of song coefficients (rows of songs that were
#          not fitted keep the initial value 1),
#   x:     n.user x p matrix of user coefficients (NaN for users that have
#          no rows left after threshold filtering),
#   resid: fitted minus observed values for the rows that were used,
#   r2:    one value per score column:
#          (var(y) - sum(resid^2) / n) / var(y).
#
# Depends on the globals `elemname`, `p`, `n.reqno` and `n.user`.
diaglm <- function(dat, threshold=4, weighted=FALSE, verbose=TRUE, cl=NULL) {
  # The integer codes of these factors are used as row indices of `a`/`x`.
  stopifnot(is.factor(dat$requestNo), is.factor(dat$user))
  map <- if (is.null(cl)) sapply else function(...) parSapply(cl, ...)

  # Use only songs that appear at least `threshold` times for training.
  keep <- as.vector(table(dat$requestNo)[dat$requestNo] >= threshold)
  dat <- dat[keep, ]
  if (nrow(dat) == 0) {
    stop("no song appears at least `threshold` times", call.=FALSE)
  }
  reqno <- as.numeric(dat$requestNo)
  user <- as.numeric(dat$user)
  y <- as.matrix(dat[, elemname])

  # Initial values: all song coefficients 1, user coefficients = user means.
  a <- matrix(1, n.reqno, p)
  x <- apply(y, 2, function(col) tapply(col, dat$user, mean))
  if (weighted) {
    w <- apply(y, 2, function(col) table(col)[as.character(col)])
  } else {
    # One weight per observation and score column (p columns, matching y).
    w <- matrix(1, nrow(dat), p)
  }

  # Closed-form update of the coefficients of song `j` with `xx` held fixed.
  # Column-wise: sum(sqrt(w) * x * y) / sum(sqrt(w) * x^2), i.e. the diagonal
  # of (t(sqrt(w) * x) %*% y) / (t(sqrt(w) * x) %*% x).
  # NOTE(review): sqrt(w) enters only once per product, so the effective
  # least-squares weight is sqrt(w) rather than w -- confirm this is intended.
  next.a <- function(j, y, xx, w, reqno, user) {
    rows <- reqno == j
    sub.y <- y[rows, , drop=FALSE]
    sub.x <- xx[user[rows], , drop=FALSE]
    sub.wx <- sqrt(w[rows, , drop=FALSE]) * sub.x
    colSums(sub.wx * sub.y) / colSums(sub.wx * sub.x)
  }
  # Same update for the coefficients of user `i` with `a` held fixed.
  next.x <- function(i, y, a, w, reqno, user) {
    rows <- user == i
    sub.y <- y[rows, , drop=FALSE]
    sub.a <- a[reqno[rows], , drop=FALSE]
    sub.wa <- sqrt(w[rows, , drop=FALSE]) * sub.a
    colSums(sub.wa * sub.y) / colSums(sub.wa * sub.a)
  }

  # Songs that have data (loop-invariant), and users that have data. Users
  # without any row have NA/NaN coefficients and must not enter the
  # convergence measure, otherwise it becomes NA and the loop test fails.
  j <- which(table(dat$requestNo) > 0)
  active <- tabulate(user, nbins=n.user) > 0

  # Alternating optimisation
  iter <- 0
  err <- Inf
  while (err > 1e-07) {
    iter <- iter + 1
    a0 <- a
    x0 <- x
    a[j, ] <- t(map(j, next.a, y=y, xx=x, w=w, reqno=reqno, user=user))
    x <- t(map(seq_len(n.user), next.x, y=y, a=a, w=w, reqno=reqno, user=user))
    dx <- (x[active, , drop=FALSE] - x0[active, , drop=FALSE]) * 0.01
    err <- sum((a - a0)^2) / (n.reqno * p) + sum(dx^2) / (n.user * p)
    if (verbose) {
      cat(paste0("#", iter, ": ", round(err, digits=8), "\n"))
    }
  }

  resid <- a[reqno, ] * x[user, ] - y
  var.y <- apply(y, 2, var)
  r2 <- unname((var.y - colSums(resid^2) / nrow(dat)) / var.y)
  colnames(a) <- elemname
  colnames(x) <- elemname
  list(a=a, x=x, resid=resid, r2=r2)
}
## Estimation ----
cl <- makeCluster(4)
# diaglm: unweighted fit on songs seen at least 4 times; the prediction for a
# test row is the product of its song and user coefficients.
diaglm.std <- diaglm(dat.train, threshold=4, weighted=FALSE, cl=cl)
pred.y.by.diaglm.std <- with(diaglm.std, a[reqno.test, ] * x[user.test, ])
c(mse=mse(y.test, pred.y.by.diaglm.std), r2=diaglm.std$r2)
# weighted diaglm: no threshold, frequency-weighted observations.
diaglm.w <- diaglm(dat.train, threshold=0, weighted=TRUE, cl=cl)
pred.y.by.diaglm.w <- with(diaglm.w, a[reqno.test, ] * x[user.test, ])
c(mse=mse(y.test, pred.y.by.diaglm.w), r2=diaglm.w$r2)
# threshold: refit for thresholds 1..6 and collect test MSE and R^2, one row
# per threshold.
# NOTE(review): these fits do not pass `cl`, so they run without the cluster
# -- confirm this is intended.
summary.diaglm.t <- do.call(rbind, lapply(1 : 6, function(threshold) {
  fit <- diaglm(dat.train, threshold=threshold)
  c(mse=mse(y.test, fit$a[reqno.test, ] * fit$x[user.test, ]), r2=fit$r2)
}))
summary.diaglm.t
stopCluster(cl)
## Compare the mean baseline and diaglm on the Interval score only
# True test values next to both predictions.
int.df <- data.frame(
  true=y.test[, "chartInterval"],
  pred.mean=pred.y.by.mean[, "chartInterval"],
  pred.diaglm=pred.y.by.diaglm.std[, "chartInterval"]
)
# Predicted vs. true scatter plots with the identity line; both axes are
# limited to [50, 100].
g1 <- ggplot(int.df, aes(x=true, y=pred.diaglm)) +
  geom_point() +
  geom_abline(intercept=0, slope=1) +
  xlim(50, 100) +
  ylim(50, 100) +
  labs(x="True int", y="Predicted int (proposed)")
g2 <- ggplot(int.df, aes(x=true, y=pred.mean)) +
  geom_point() +
  geom_abline(intercept=0, slope=1) +
  xlim(50, 100) +
  ylim(50, 100) +
  labs(x="True int", y="Predicted int (baseline)")
# Draw both plots side by side (1 x 2 layout) into an SVG file.
svg("interval.svg", width=8, height=4)
make.grid(1, 2)
print.at(g1, 1, 1)
print.at(g2, 1, 2)
end.grid()
dev.off()
## Look at the samples with the largest residuals ----
int.df$resid.diaglm <- int.df$pred.diaglm - int.df$true
int.df$resid.mean <- int.df$pred.mean - int.df$true
# Sort by absolute diaglm residual, largest first.
int.df <- int.df[order(abs(int.df$resid.diaglm), decreasing=TRUE), ]
# head() instead of [1:20, ]: with fewer than 20 test rows, indexing past the
# end would append all-NA rows.
head(int.df[, c("true", "resid.diaglm", "resid.mean")], 20)
## Which songs are hard to sing in tune? ----
# Fit on all data. `finally` makes sure the worker processes are stopped even
# when the fit fails.
cl <- makeCluster(4)
diaglm.full <- tryCatch(
  diaglm(dat, threshold=4, weighted=FALSE, cl=cl),
  finally=stopCluster(cl)
)
# Song coefficient of the first score column (chartInterval) next to the song
# metadata. Songs excluded by the threshold keep the initial coefficient 1.
df <- data.frame(
  a=diaglm.full$a[, 1],
  songs[, c("diffPitch", "contents", "artist")]
)
# Ten smallest and ten largest coefficients. head() avoids the all-NA rows
# that [1 : 10, ] would append when there are fewer than ten songs.
head(df[order(df$a), ], 10)
head(df[order(df$a, decreasing=TRUE), ], 10)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment