-
-
Save mike-lawrence/716973647a9656133c49e012f4547103 to your computer and use it in GitHub Desktop.
periodic likelihood dashboard
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#preamble ---- | |
#load packages used | |
library(shiny) #for reactive viz | |
library(tidyverse) #for all that is good and holy | |
# Sine wave parameterized by frequency and phase.
#   x:     time points (should be in seconds)
#   hz:    frequency, in cycles per unit of x
#   phase: phase offset, in radians
# Returns sin(x*2*pi*hz + phase), vectorized over x.
sine = function(x,hz,phase){
  angle = x*(2*pi)*hz + phase #convert time to radians, then shift by phase
  return(sin(angle))
}
# Get the log-probability of the data given a set of parameters.
#   pars: list-like with elements amp, hz, phase & noise
#   dat:  list-like with elements x (time) & y (observations)
# The observations are mean-centered, then scored against a sinusoid with the
# proposed amplitude/frequency/phase under independent Gaussian noise with
# standard deviation pars$noise. Returns the summed log-density (a scalar).
get_lp = function(pars,dat){
  f = pars$amp*sine(dat$x,pars$hz,pars$phase)
  # log=TRUE spelled out: `T` is an ordinary variable that can be reassigned
  obs_lp = sum(dnorm(dat$y-mean(dat$y),f,pars$noise,log=TRUE))
  return(obs_lp)
  # alternative (disabled): express relative to the lp of a noiseless fit
  # pure_lp = sum(dnorm(f,f,pars$noise,log=TRUE))
  # return(obs_lp-pure_lp)
}
# Generate data ----
# parameters of the sinusoid used to simulate the observations
true_pars <- list(
  noise = .1,
  amp = 1,
  hz = exp(1)/3, #arbitrary but not a multiple of sample rate or time window
  phase = 0
)
# fix the RNG state so the simulated noise is reproducible
seed <- 1
set.seed(seed)
# x: sample times; f: noiseless signal; y: signal plus Gaussian noise
dat <- tibble::tibble(
  x = seq(-5,5,by=1/100), # setting this as *not* centered on zero yields diagonal modes
  f = true_pars$amp*sine(x,hz=true_pars$hz,phase=true_pars$phase),
  y = f + rnorm(length(x),0,true_pars$noise)
)
# quick viz: noiseless signal as a green line, noisy observations as points
ggplot(dat) +
  geom_line(
    aes(x=x,y=f),
    colour = 'green',
    size = 2
  ) +
  geom_point(
    aes(x=x,y=y),
    alpha = .5
  )
# compute likelihood topography ----
# if not already computed, takes about 3mins on a recentish 6 core machine
if(file.exists('lp.rds')){
  #if lp has already been computed/saved, load it
  lp = readRDS('lp.rds')
}else{
  # number of workers: half the detected cores, but always a whole number >= 1
  # (detectCores() can return NA, and an odd count would give a fractional value)
  n_cores = parallel::detectCores()
  n_jobs = if(is.na(n_cores)) 1 else max(1, floor(n_cores/2))
  # define the parameter space to map, then split into chunks for parallel processing
  pars_chunks =
    (
      expand_grid(
        # hz = seq(.001,3,by=.001)
        hz = true_pars$hz*seq(.01,2,by=.01)
        , phase = round(seq(-pi*2,pi*2,length.out=1+1e2),15)
        , amp = 1
        , noise = .1
      )
      %>% dplyr::mutate(
        chunk = seq_len(n()) %% (n_jobs*10) #10 chunks per worker
      )
      %>% dplyr::group_by(chunk)
      %>% dplyr::group_split()
    )
  #define a function to run on each chunk: adds a `value` column holding the
  # log-probability of `dat` for each row's (hz,phase,amp,noise) combination
  get_lp_for_chunk = function(chunk,dat){
    return(
      chunk
      # one group per parameter combination, so get_lp sees scalar parameters
      %>% dplyr::group_by(hz,phase,amp,noise)
      %>% dplyr::mutate(
        value = get_lp(
          pars = tibble(hz=hz,phase=phase,amp=amp,noise=noise)
          , dat = dat
        )
      )
      %>% dplyr::ungroup()
    )
  }
  # start processing the parameter space
  if(requireNamespace('clustermq',quietly=TRUE)){
    #if clustermq/0mq installed (and loadable), process in parallel
    options(clustermq.scheduler="multicore")
    lp =
      (
        clustermq::Q(
          fun = get_lp_for_chunk
          , chunk = pars_chunks
          , const = list(dat=dat)
          , n_jobs = n_jobs
        )
        %>% dplyr::bind_rows()
      )
  }else{
    #if 0mq not available, process serially (takes about 20mins)
    start = proc.time()[3]
    lp = (
      lapply(
        FUN = get_lp_for_chunk
        , X = pars_chunks
        , dat = dat
      )
      %>% dplyr::bind_rows()
    )
    # report elapsed time (a bare expression inside this block is not printed)
    message('Serial computation took ',round(proc.time()[3] - start),' seconds')
  }
  #cache the result so later runs can skip the computation
  saveRDS(lp,file='lp.rds')
}
# visualize in shiny ----
#rescale lp to range 0 to -1 for convenience:
# -Inf values become NA, values are divided by their range (so max-min = 1),
# then shifted so that the maximum sits at 0
lp_scaled =
  (
    lp
    %>% dplyr::ungroup()
    %>% dplyr::mutate(
      value = na_if(value,-Inf)
      # TRUE spelled out: `T` is an ordinary variable that can be reassigned
      , value = value/diff(range(value,na.rm=TRUE))
      , value = value-max(value,na.rm=TRUE)
    )
  )
# generate a minimally-formatted topo plot to get the legend & raster
temp =
  (
    lp_scaled
    %>% ggplot()
    + geom_raster(
      aes(
        x = phase
        , y = hz
        , fill = value
      )
    )
    + viridis::scale_fill_viridis(
      name = 'Relative\nLog-Probability'
      , breaks = c(-1,0)
      , labels = c('Min','Max')
      , guide = guide_colorbar(
        title.theme = element_text(size=30)
        , label.theme = element_text(size=20)
        , barwidth = unit(.1,'npc')
        , barheight = unit(.9,'npc')
      )
    )
  )
# build the plot's grob table & pull out the legend grob by name
tmp = ggplot_gtable(ggplot_build(temp))
# vapply (not sapply) so the result is always a character vector
grob_names = vapply(tmp$grobs, function(x) as.character(x$name)[1], character(1))
legend_idx = which(grob_names == "guide-box")
if(length(legend_idx) == 0){
  # `[[` with a zero-length index would otherwise fail with an obscure message
  stop('no "guide-box" grob found in the topo plot; cannot extract the legend')
}
# NOTE(review): if several grobs match, the first is used -- confirm this is
# the populated legend for the installed ggplot2 version
legend = tmp$grobs[[ legend_idx[1] ]]
# the rendered raster image of the first (only) layer, reused as an annotation
topo_raster = layer_grob(temp)$`1`$raster
#now a more nicely formatted topo plot, with raster as annotation
# (the pre-rendered raster is stretched over the full phase x hz extent and
#  the true parameter values are marked with a cross)
topo_plot <-
  ggplot(
    data = lp_scaled,
    mapping = aes(x=phase,y=hz) #necessary for nearPoints
  ) +
  annotation_raster(
    raster = topo_raster,
    xmin = min(lp_scaled$phase),
    xmax = max(lp_scaled$phase),
    ymin = min(lp_scaled$hz),
    ymax = max(lp_scaled$hz)
  ) +
  geom_point(
    data = tibble(hz=true_pars$hz,phase=true_pars$phase),
    mapping = aes(x=phase,y=hz),
    shape = 4
  ) +
  labs(
    y = 'Frequency (Hz)',
    x = 'Phase (radians)'
  ) +
  scale_y_continuous(
    expand = c(0,0),
    limits = range(lp$hz)
  ) +
  scale_x_continuous(
    position = 'top',
    limits = range(lp$phase),
    expand = c(0,0),
    breaks = c(-pi/2,0,pi/2),
    labels = c('-𝜋/2',0,'+𝜋/2')
  ) +
  viridis::scale_fill_viridis() +
  guides(fill='none') +
  theme(
    aspect.ratio = 1
  )
# normalized autocorrelations (for lags with >=4 observations)
#   x: numeric vector
# Returns one value per lag 1..(length(x)-4): the lag-i autocorrelation,
# Fisher r-to-z transformed and scaled to unit-normal expectation.
# Returns a zero-length numeric vector when length(x) <= 4 (no usable lags).
autocorzse = function(x){
  nx = length(x)
  # seq_len + max guard: `1:(nx-4)` would count down (1,0) when nx==4, and
  # rep(NA,nx-4) would error when nx<4
  lags = seq_len(max(nx-4,0))
  zse = vapply(
    lags
    , function(i){
      atanh( #atanh "normalizes" correlations (Fisher's r-to-z transform)
        cor( x[1:(nx-i)] , x[(1+i):nx] )
      )*sqrt(nx-i-3) #after normalization, z's have an se=1/sqrt(N-3), so multiplying by sqrt(N-3) makes for unit normal expectation
    }
    , numeric(1)
  )
  return(zse)
}
# UNnormalized autocorrelations (for lags with >=4 observations)
#   x: numeric vector
# Returns the raw lag-i correlation for each lag 1..(length(x)-4), or a
# zero-length numeric vector when length(x) <= 4 (no usable lags).
autocor = function(x){
  nx = length(x)
  # seq_len + max guard: `1:(nx-4)` would count down (1,0) when nx==4, and
  # rep(NA,nx-4) would error when nx<4
  lags = seq_len(max(nx-4,0))
  r = vapply(
    lags
    , function(i) cor( x[1:(nx-i)] , x[(1+i):nx] )
    , numeric(1)
  )
  return(r)
}
# Page layout: a wide left column holding a 2x2 grid (topography + frequency
# slice on top, phase slice + colour legend below) and a narrower right column
# stacking the time-domain and autocorrelation diagnostics.
ui <- fluidPage(
  fluidRow(
    # left: likelihood topography and its slices
    column(
      width = 7,
      fluidRow(
        column(
          width = 6,
          plotOutput("topo", height = '500px', click = "plot_click")
        ),
        column(
          width = 6,
          plotOutput("hz", height = '500px')
        )
      ),
      fluidRow(
        column(
          width = 6,
          plotOutput("phase", height = '500px')
        ),
        column(
          width = 6,
          plotOutput('legend', height = '600px')
        )
      )
    ),
    # right: diagnostics for the currently selected proposal
    column(
      width = 4,
      offset = 1,
      plotOutput("td", height = '200px'),
      plotOutput("td_resid", height = '200px'),
      plotOutput("td_ysum", height = '200px'),
      plotOutput("ac", height = '200px'),
      plotOutput("ac_zse", height = '200px')
    )
  )
)
# Shiny server: tracks the grid point nearest the last click on the topography
# and redraws the slices, time-domain views and residual autocorrelations for
# the sinusoid proposed by that (hz, phase) pair.
server <- function(input, output, session) {
  # currently selected parameters; starts at the true generating values
  nearest <- reactiveValues(hz = true_pars$hz, phase = true_pars$phase)

  # update the selection when the topography is clicked
  observeEvent(input$plot_click, {
    nearest_click <- nearPoints(lp_scaled, input$plot_click, threshold = 10, maxpoints = 1, addDist = TRUE)
    # a click with no grid point within the threshold returns zero rows;
    # keep the previous selection rather than blanking every plot
    req(nrow(nearest_click) > 0)
    nearest$hz <- nearest_click$hz
    nearest$phase <- nearest_click$phase
  })

  # proposal sinusoid evaluated at the observed time points (shared by the
  # time-domain and autocorrelation plots)
  proposal <- reactive({
    sine(dat$x, hz = nearest$hz[1], phase = nearest$phase[1])
  })

  # topography with cross-hairs at the current selection
  output$topo <- renderPlot({
    (
      topo_plot
      + geom_vline(
        xintercept = nearest$phase
        , alpha = .5
        , linetype = 1
        , size = .5
      )
      + geom_hline(
        yintercept = nearest$hz
        , alpha = .5
        , linetype = 1
        , size = .5
      )
    )
  })

  # colour-bar legend extracted from the raw topo plot
  output$legend <- renderPlot({
    grid::grid.newpage()
    grid::grid.draw(legend)
  })

  # slice through the topography along frequency, at the selected phase
  output$hz <- renderPlot({
    (
      lp_scaled
      %>% dplyr::filter(
        # near() rather than ==: the initial selection is not guaranteed to be
        # bit-identical to a grid value
        dplyr::near(phase, nearest$phase[1])
      )
      %>% ggplot(
        mapping = aes(
          x = hz
          , y = value
        )
      )
      + geom_line(size=4)
      + geom_line(aes(colour = value),size=2)
      # + geom_point(aes(colour = value),alpha = .5, size=4)
      + labs(
        y = 'Relative Log-Probability'
      )
      + scale_x_continuous(expand = c(0,0))
      + scale_y_continuous(
        limits = c(-1,0)
        , breaks = c(-1,0)
        , minor_breaks = c(-.75,-.5,-.25) #negative: must fall inside the limits
        , labels = c('Min','Max')
        , position = 'right'
      )
      + viridis::scale_color_viridis(
        limits = c(-1,0)
      )
      + guides(colour='none')
      + coord_flip()
      + theme(
        aspect.ratio = 1
        , axis.title.y = element_text(colour='transparent')
        , axis.text.y = element_text(colour='transparent')
      )
    )
  })

  # slice through the topography along phase, at the selected frequency
  output$phase <- renderPlot({
    (
      lp_scaled
      %>% dplyr::filter(
        # near() rather than ==: see the frequency slice above
        dplyr::near(hz, nearest$hz[1])
      )
      %>% ggplot(
        mapping = aes(
          x = phase
          , y = value
        )
      )
      + geom_line(size=4)
      + geom_line(aes(colour = value),size=2)
      # + geom_point(aes(colour = value),alpha = .5,size=10)
      + labs(
        y = 'Relative Log-Probability'
      )
      + scale_x_continuous(
        position = 'top'
        , expand = c(0,0)
        , breaks = c(-pi/2,0,pi/2)
      )
      + scale_y_continuous(
        limits = c(-1,0)
        , breaks = c(-1,0)
        , minor_breaks = c(-.75,-.5,-.25) #negative: must fall inside the limits
        , labels = c('Min','Max')
        # , labels = scales::number_format(accuracy=.1)
      )
      + viridis::scale_color_viridis(
        limits = c(-1,0)
      )
      + guides(colour='none')
      + theme(
        aspect.ratio = 1
        , axis.title.x = element_text(colour='transparent')
        , axis.text.x = element_text(colour='transparent')
      )
    )
  })

  # observations with the proposal sinusoid overlaid
  output$td <- renderPlot({
    (
      dat
      %>% ggplot(mapping=aes(x=x,y=y))
      + geom_point(colour = 'black',alpha=.5)
      + geom_line(
        data = tibble::tibble(
          x = dat$x
          , y = proposal()
        )
        , colour = 'red'
        , alpha = .5
        , size = 2
      )
      + labs(
        title = 'Time-domain: Observed & Proposal'
        , x = 'Time (s)'
        , y = 'Value'
      )
    )
  })

  # residuals: observations minus proposal
  output$td_resid <- renderPlot({
    (
      dat
      %>% dplyr::mutate(
        y = y - proposal()
      )
      %>% ggplot(mapping=aes(x=x,y=y))
      + geom_point(colour = 'black',alpha=.5)
      + labs(
        title = 'Time-domain: Residuals (Observed - Proposal)'
        , x = 'Time (s)'
        , y = 'Observed - Proposal'
      )
    )
  })

  # observations plus proposal
  output$td_ysum <- renderPlot({
    (
      dat
      %>% dplyr::mutate(
        y = y + proposal()
      )
      %>% ggplot(mapping=aes(x=x,y=y))
      + geom_point(colour = 'black',alpha=.5)
      + labs(
        title = 'Time-domain: Observed + Proposal ("beats"?)'
        , x = 'Time (s)'
        , y = 'Observed + Proposal'
      )
    )
  })

  # noiseless true signal plus proposal
  # NOTE(review): the ui has no plotOutput('td_fsum'), so this is not displayed
  output$td_fsum <- renderPlot({
    (
      dat
      %>% dplyr::mutate(
        y = f + proposal()
      )
      %>% ggplot(mapping=aes(x=x,y=y))
      + geom_line()
      + labs(
        title = 'Time-domain: True + Proposal ("beats"?)'
        , x = 'Time (s)'
        , y = 'True + Proposal'
      )
    )
  })

  # raw autocorrelations of the residuals, by lag
  output$ac <- renderPlot({
    f <- proposal()
    (
      tibble::tibble(
        y = autocor(dat$y-f)
        , x = seq_along(y)
      )
      %>% ggplot()
      + geom_line(mapping=aes(x=x,y=y))
      + scale_y_continuous(
        limits=c(-1,1)
        , expand = c(0,0)
        , breaks = c(-.5,0,.5)
      )
      + labs(
        title = 'Residual Autocorrelations: Raw'
        , x = 'Lag'
      )
      + theme(
        axis.title.y = element_blank()
      )
    )
  })

  # normalized autocorrelations of the residuals, with a +/-3 reference band
  output$ac_zse <- renderPlot({
    f <- proposal()
    (
      tibble::tibble(
        y = autocorzse(dat$y-f)
        , x = seq_along(y)
      )
      %>% ggplot()
      + geom_line(mapping=aes(x=x,y=y))
      + geom_ribbon(
        data = tibble(x=c(-Inf,Inf),ymin=c(-3,-3),ymax=c(3,3))
        , mapping = aes(x=x,ymin=ymin,ymax=ymax)
        , fill = 'green'
        , colour = 'transparent'
        , alpha = .5
      )
      + labs(
        title = 'Residual Autocorrelations: normalized'
        , x = 'Lag'
      )
      + theme(
        axis.title.y = element_blank()
      )
    )
  })
}
shinyApp(ui, server) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment