Skip to content

Instantly share code, notes, and snippets.

@mike-lawrence
Last active December 17, 2020 17:24
Show Gist options
  • Save mike-lawrence/716973647a9656133c49e012f4547103 to your computer and use it in GitHub Desktop.
Save mike-lawrence/716973647a9656133c49e012f4547103 to your computer and use it in GitHub Desktop.
periodic likelihood dashboard
#preamble ----
#load packages used
library(shiny) #for reactive viz
library(tidyverse) #for all that is good and holy
# Sine wave parameterized by frequency (Hz) and phase (radians); x is time in
# seconds. Returns sin(x * 2*pi * hz + phase), vectorized over x (hz and phase
# recycle like any R arithmetic).
sine <- function(x, hz, phase) {
  angle <- x * (2 * pi) * hz + phase
  sin(angle)
}
# Log-probability of the observed data under a sinusoid with given parameters.
#
# pars: list/tibble with elements `amp`, `hz`, `phase`, `noise` (each used as
#       a single value here).
# dat:  data frame with columns `x` (time, seconds) and `y` (observations).
#
# The observations are mean-centred before scoring, so a constant offset in `y`
# does not change the result. Returns the summed Gaussian log-density of the
# centred observations around amp * sine(x, hz, phase) with sd = noise.
get_lp <- function(pars, dat) {
  f <- pars$amp * sine(dat$x, pars$hz, pars$phase)
  centred_y <- dat$y - mean(dat$y)
  # Alternative (disabled): score relative to a noise-free fit, i.e.
  #   obs_lp - sum(dnorm(f, f, pars$noise, log = TRUE))
  sum(dnorm(centred_y, mean = f, sd = pars$noise, log = TRUE))
}
# Generate data ----
# Ground-truth parameters of the simulated sinusoid.
true_pars <- list(
  noise = .1,
  amp = 1,
  hz = exp(1) / 3, # arbitrary but not a multiple of sample rate or time window
  phase = 0
)
seed <- 1
set.seed(seed)
# Simulated observations sampled every 1/100 s over [-5, 5] s: the true signal
# `f` plus Gaussian noise in `y`.
# Setting the time window as *not* centered on zero yields diagonal modes.
dat <- tibble::tibble(
  x = seq(-5, 5, by = 1 / 100),
  f = true_pars$amp * sine(x, hz = true_pars$hz, phase = true_pars$phase),
  y = f + rnorm(n = length(x), mean = 0, sd = true_pars$noise)
)
# quick viz: true signal (thick green line) over the noisy observations
ggplot(data = dat) +
  geom_line(mapping = aes(x = x, y = f), colour = "green", size = 2) +
  geom_point(mapping = aes(x = x, y = y), alpha = .5)
# compute likelihood topography ----
# Evaluates get_lp() over a (hz, phase) grid and caches the result as 'lp.rds'
# in the working directory. The cache is not invalidated automatically: delete
# 'lp.rds' after changing the grid or the data.
# If not already computed, takes about 3mins on a recentish 6 core machine.
if (file.exists('lp.rds')) {
  # lp has already been computed/saved: load it
  lp <- readRDS('lp.rds')
} else {
  # Worker count: half the detected cores, but a whole number >= 1.
  # detectCores() can return NA, and an odd core count would otherwise give a
  # fractional job count.
  n_workers <- as.integer(max(1, floor(parallel::detectCores() / 2), na.rm = TRUE))
  # define the parameter space to map, then split into chunks for parallel
  # processing (10 chunks per worker)
  pars_chunks <- (
    expand_grid(
      # hz = seq(.001,3,by=.001)
      hz = true_pars$hz * seq(.01, 2, by = .01)
      , phase = round(seq(-pi * 2, pi * 2, length.out = 1 + 1e2), 15)
      , amp = 1
      , noise = .1
    )
    %>% dplyr::mutate(chunk = seq_len(dplyr::n()) %% (n_workers * 10L))
    %>% dplyr::group_by(chunk)
    %>% dplyr::group_split()
  )
  # Score every parameter combination in one chunk. Grouping by all four
  # parameter columns makes each group a single grid row, so get_lp() sees
  # one value per parameter; the result adds a `value` column (log-probability).
  get_lp_for_chunk <- function(chunk, dat) {
    chunk %>%
      dplyr::group_by(hz, phase, amp, noise) %>%
      dplyr::mutate(
        value = get_lp(
          pars = tibble(hz = hz, phase = phase, amp = amp, noise = noise)
          , dat = dat
        )
      )
  }
  # start processing the parameter space
  if (requireNamespace('clustermq', quietly = TRUE)) {
    # clustermq (with its ZeroMQ backend) loads: process in parallel
    options(clustermq.scheduler = "multicore")
    lp <- (
      clustermq::Q(
        fun = get_lp_for_chunk
        , chunk = pars_chunks
        , const = list(dat = dat)
        , n_jobs = n_workers
      )
      %>% dplyr::bind_rows()
    )
  } else {
    # clustermq not available: process serially (takes about 20mins)
    start <- proc.time()[["elapsed"]]
    lp <- (
      lapply(X = pars_chunks, FUN = get_lp_for_chunk, dat = dat)
      %>% dplyr::bind_rows()
    )
    # report elapsed time explicitly; a bare expression here is never printed
    message(sprintf('Serial computation took %.1f s', proc.time()[["elapsed"]] - start))
  }
  saveRDS(lp, file = 'lp.rds')
}
# visualize in shiny ----
# Rescale lp so its finite log-probabilities span exactly [-1, 0]: -Inf values
# become NA, values are divided by their (finite) range, then shifted so the
# maximum is 0. The grouping left by the topography step is dropped first.
lp_scaled <- (
  lp
  %>% dplyr::ungroup()
  %>% dplyr::mutate(
    value = dplyr::na_if(value, -Inf)
    , value = value / diff(range(value, na.rm = TRUE))
    , value = value - max(value, na.rm = TRUE)
  )
)
# generate a minimally-formatted topo plot to get the legend & raster
temp <- (
  lp_scaled
  %>% ggplot()
  + geom_raster(aes(x = phase, y = hz, fill = value))
  + viridis::scale_fill_viridis(
    name = 'Relative\nLog-Probability'
    , breaks = c(-1, 0)
    , labels = c('Min', 'Max')
    , guide = guide_colorbar(
      title.theme = element_text(size = 30)
      , label.theme = element_text(size = 20)
      , barwidth = unit(.1, 'npc')
      , barheight = unit(.9, 'npc')
    )
  )
)
tmp <- ggplot_gtable(ggplot_build(temp))
# Locate the legend in the gtable the same way gtable::gtable_filter() does:
# by (substring) match on the layout name, so "guide-box" and any
# position-suffixed variant both match. Empty placeholder slots (zeroGrob) are
# skipped. Layout rows and grobs are parallel in a gtable.
legend_idx <- which(
  grepl('guide-box', tmp$layout$name, fixed = TRUE)
  & !vapply(tmp$grobs, inherits, logical(1), what = 'zeroGrob')
)
if (length(legend_idx) == 0) {
  stop('Could not find the legend ("guide-box") in the topo plot gtable', call. = FALSE)
}
legend <- tmp$grobs[[legend_idx[1]]]
topo_raster <- layer_grob(temp)$`1`$raster
# now a more nicely formatted topo plot: the raster extracted above is drawn as
# an annotation, and the true parameters are marked with an "x" (shape 4).
# The (phase, hz) mapping on lp_scaled is kept because nearPoints needs it.
topo_plot <- ggplot(data = lp_scaled, mapping = aes(x = phase, y = hz)) +
  annotation_raster(
    raster = topo_raster,
    xmin = min(lp_scaled$phase),
    xmax = max(lp_scaled$phase),
    ymin = min(lp_scaled$hz),
    ymax = max(lp_scaled$hz)
  ) +
  geom_point(
    data = tibble(hz = true_pars$hz, phase = true_pars$phase),
    mapping = aes(x = phase, y = hz),
    shape = 4
  ) +
  labs(y = 'Frequency (Hz)', x = 'Phase (radians)') +
  scale_y_continuous(expand = c(0, 0), limits = range(lp$hz)) +
  scale_x_continuous(
    position = 'top',
    limits = range(lp$phase),
    expand = c(0, 0),
    breaks = c(-pi / 2, 0, pi / 2),
    labels = c('-𝜋/2', 0, '+𝜋/2')
  ) +
  viridis::scale_fill_viridis() +
  guides(fill = 'none') +
  theme(aspect.ratio = 1)
# normalized autocorrelations (for lags with >=4 observations)
# For each lag k = 1, ..., length(x) - 4, correlates x with itself shifted by k
# (n_k = length(x) - k overlapping pairs), applies Fisher's r-to-z transform
# (atanh) and scales by sqrt(n_k - 3). Returns a numeric vector of length
# max(0, length(x) - 4); inputs too short for any lag give numeric(0).
autocorzse <- function(x) {
  nx <- length(x)
  lags <- seq_len(max(0L, nx - 4L))
  vapply(
    lags,
    function(lag) {
      n_pairs <- nx - lag
      # atanh "normalizes" correlations (Fisher's r-to-z transform); after
      # normalization, z's have an se = 1/sqrt(N-3), so multiplying by
      # sqrt(N-3) makes for unit normal expectation
      atanh(cor(x[seq_len(n_pairs)], x[(1 + lag):nx])) * sqrt(n_pairs - 3)
    },
    numeric(1)
  )
}
# UNnormalized autocorrelations (for lags with >=4 observations)
# For each lag k = 1, ..., length(x) - 4, returns cor(x[1:(n-k)], x[(1+k):n]).
# Returns a numeric vector of length max(0, length(x) - 4); inputs too short
# for any lag give numeric(0).
autocor <- function(x) {
  nx <- length(x)
  lags <- seq_len(max(0L, nx - 4L))
  vapply(
    lags,
    function(lag) cor(x[seq_len(nx - lag)], x[(1 + lag):nx]),
    numeric(1)
  )
}
# Dashboard layout. Left (width 7): a 2x2 grid with the clickable topography,
# the hz profile, the phase profile and the colour legend. Right (width 4,
# offset 1): a stack of time-domain and residual-autocorrelation diagnostics.
# Each id must match an output$<id> renderer in server().
# NOTE(review): server() also renders output$td_fsum, which has no slot here
# and is therefore never shown.
ui <- fluidPage(
  fluidRow(
    column(
      width = 7,
      fluidRow(
        column(width = 6, plotOutput("topo", height = "500px", click = "plot_click")),
        column(width = 6, plotOutput("hz", height = "500px"))
      ),
      fluidRow(
        column(width = 6, plotOutput("phase", height = "500px")),
        column(width = 6, plotOutput("legend", height = "600px"))
      )
    ),
    column(
      width = 4,
      offset = 1,
      plotOutput("td", height = "200px"),
      plotOutput("td_resid", height = "200px"),
      plotOutput("td_ysum", height = "200px"),
      plotOutput("ac", height = "200px"),
      plotOutput("ac_zse", height = "200px")
    )
  )
)
server <- function(input, output, session) {
  # Selected grid point; starts at the true parameters. Renderers below read
  # hz/phase via [1].
  nearest <- reactiveValues(hz = true_pars$hz, phase = true_pars$phase)

  # Click on the topo plot -> select the nearest grid point of lp_scaled.
  # A click with no grid point within 10 px is ignored instead of replacing
  # the selection with an empty one (which would blank every dependent plot).
  observeEvent(input$plot_click, {
    nearest_click <- nearPoints(
      lp_scaled
      , input$plot_click
      , threshold = 10
      , maxpoints = 1
      , addDist = TRUE
    )
    req(nrow(nearest_click) > 0)
    nearest$hz <- nearest_click$hz
    nearest$phase <- nearest_click$phase
  })

  # Proposal sinusoid (amplitude 1) for the selected point at the observation
  # times; shared by the time-domain and autocorrelation plots.
  proposal <- reactive({
    sine(dat$x, hz = nearest$hz[1], phase = nearest$phase[1])
  })

  # Topography with cross-hairs through the selected point.
  output$topo <- renderPlot({
    (
      topo_plot
      + geom_vline(
        xintercept = nearest$phase
        , alpha = .5
        , linetype = 1
        , size = .5
      )
      + geom_hline(
        yintercept = nearest$hz
        , alpha = .5
        , linetype = 1
        , size = .5
      )
    )
  })

  # Colour legend extracted from the raster plot.
  output$legend <- renderPlot({
    grid::grid.newpage()
    grid::grid.draw(legend)
  })

  # Profile over hz at the selected phase (flipped to sit beside the topo
  # plot). Grid values are matched with a tolerance (dplyr::near), not ==,
  # so the initial true-parameter selection still finds its grid row.
  output$hz <- renderPlot({
    (
      lp_scaled
      %>% dplyr::filter(dplyr::near(phase, nearest$phase[1]))
      %>% ggplot(mapping = aes(x = hz, y = value))
      + geom_line(size = 4)
      + geom_line(aes(colour = value), size = 2)
      # + geom_point(aes(colour = value),alpha = .5, size=4)
      + labs(y = 'Relative Log-Probability')
      + scale_x_continuous(expand = c(0, 0))
      + scale_y_continuous(
        limits = c(-1, 0)
        , breaks = c(-1, 0)
        # minor breaks must lie inside limits = c(-1, 0) to be drawn
        , minor_breaks = c(-.75, -.5, -.25)
        , labels = c('Min', 'Max')
        , position = 'right'
      )
      + viridis::scale_color_viridis(limits = c(-1, 0))
      + guides(colour = 'none')
      + coord_flip()
      + theme(
        aspect.ratio = 1
        , axis.title.y = element_text(colour = 'transparent')
        , axis.text.y = element_text(colour = 'transparent')
      )
    )
  })

  # Profile over phase at the selected hz (tolerance match, as above).
  output$phase <- renderPlot({
    (
      lp_scaled
      %>% dplyr::filter(dplyr::near(hz, nearest$hz[1]))
      %>% ggplot(mapping = aes(x = phase, y = value))
      + geom_line(size = 4)
      + geom_line(aes(colour = value), size = 2)
      # + geom_point(aes(colour = value),alpha = .5,size=10)
      + labs(y = 'Relative Log-Probability')
      + scale_x_continuous(
        position = 'top'
        , expand = c(0, 0)
        , breaks = c(-pi / 2, 0, pi / 2)
      )
      + scale_y_continuous(
        limits = c(-1, 0)
        , breaks = c(-1, 0)
        # minor breaks must lie inside limits = c(-1, 0) to be drawn
        , minor_breaks = c(-.75, -.5, -.25)
        , labels = c('Min', 'Max')
        # , labels = scales::number_format(accuracy=.1)
      )
      + viridis::scale_color_viridis(limits = c(-1, 0))
      + guides(colour = 'none')
      + theme(
        aspect.ratio = 1
        , axis.title.x = element_text(colour = 'transparent')
        , axis.text.x = element_text(colour = 'transparent')
      )
    )
  })

  # Observations with the proposal overlaid.
  output$td <- renderPlot({
    (
      dat
      %>% ggplot(mapping = aes(x = x, y = y))
      + geom_point(colour = 'black', alpha = .5)
      + geom_line(
        data = tibble::tibble(x = dat$x, y = proposal())
        , colour = 'red'
        , alpha = .5
        , size = 2
      )
      + labs(
        title = 'Time-domain: Observed & Proposal'
        , x = 'Time (s)'
        , y = 'Value'
      )
    )
  })

  # Residuals: observations minus proposal.
  output$td_resid <- renderPlot({
    (
      dat
      %>% dplyr::mutate(y = y - proposal())
      %>% ggplot(mapping = aes(x = x, y = y))
      + geom_point(colour = 'black', alpha = .5)
      + labs(
        title = 'Time-domain: Residuals (Observed - Proposal)'
        , x = 'Time (s)'
        , y = 'Observed - Proposal'
      )
    )
  })

  # Observations plus proposal.
  output$td_ysum <- renderPlot({
    (
      dat
      %>% dplyr::mutate(y = y + proposal())
      %>% ggplot(mapping = aes(x = x, y = y))
      + geom_point(colour = 'black', alpha = .5)
      + labs(
        title = 'Time-domain: Observed + Proposal ("beats"?)'
        , x = 'Time (s)'
        , y = 'Observed + Proposal'
      )
    )
  })

  # True signal plus proposal.
  # NOTE(review): the ui has no plotOutput("td_fsum"), so this is never shown.
  output$td_fsum <- renderPlot({
    (
      dat
      %>% dplyr::mutate(y = f + proposal())
      %>% ggplot(mapping = aes(x = x, y = y))
      + geom_line()
      + labs(
        title = 'Time-domain: True + Proposal ("beats"?)'
        , x = 'Time (s)'
        , y = 'True + Proposal'
      )
    )
  })

  # Raw autocorrelations of the residuals by lag.
  output$ac <- renderPlot({
    (
      tibble::tibble(
        y = autocor(dat$y - proposal())
        , x = seq_along(y)
      )
      %>% ggplot()
      + geom_line(mapping = aes(x = x, y = y))
      + scale_y_continuous(
        limits = c(-1, 1)
        , expand = c(0, 0)
        , breaks = c(-.5, 0, .5)
      )
      + labs(
        title = 'Residual Autocorrelations: Raw'
        , x = 'Lag'
      )
      + theme(axis.title.y = element_blank())
    )
  })

  # Normalized (Fisher-z, sqrt(N-3)-scaled) residual autocorrelations, with a
  # shaded band over [-3, 3].
  output$ac_zse <- renderPlot({
    (
      tibble::tibble(
        y = autocorzse(dat$y - proposal())
        , x = seq_along(y)
      )
      %>% ggplot()
      + geom_line(mapping = aes(x = x, y = y))
      + geom_ribbon(
        data = tibble(x = c(-Inf, Inf), ymin = c(-3, -3), ymax = c(3, 3))
        , mapping = aes(x = x, ymin = ymin, ymax = ymax)
        , fill = 'green'
        , colour = 'transparent'
        , alpha = .5
      )
      + labs(
        title = 'Residual Autocorrelations: normalized'
        , x = 'Lag'
      )
      + theme(axis.title.y = element_blank())
    )
  })
}
# Launch the dashboard from the ui layout and server logic defined above.
shinyApp(ui = ui, server = server)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment