Skip to content

Instantly share code, notes, and snippets.

@maxdrohde
Created January 24, 2022 16:53
Show Gist options
  • Save maxdrohde/717ace7cd71fbea89fd2001d5f145e78 to your computer and use it in GitHub Desktop.
Save maxdrohde/717ace7cd71fbea89fd2001d5f145e78 to your computer and use it in GitHub Desktop.
library(MASS)
library(plotrix)
library(tidyverse)
library(gganimate)
####################################
## Example 1 (Data by Group with means and estimated means)
####################################
## Simulate k groups of `obs` draws each, compute the observed group means and
## shrink them toward the grand mean with the Efron-Morris (positive-part)
## James-Stein estimator. Globals used further down the script: `obs`, `wdat`,
## `js.mean`, `grand`; loss summaries: `grp.tl`, `js.tl`, `js.beat`.

## Number of groups, observations per group, upper bound for the true means
k <- 30
obs <- 10
top <- 10
## Shrinking toward the (estimated) grand mean uses k - 3 in the shrinkage
## factor, so at least 4 groups are needed for any shrinkage to happen.
stopifnot(k > 3, obs >= 1)

## Generate data
sig2 <- 16                              # within-group variance
mu <- round(runif(n = k, 0, top), 1)   # true means; alt: round(rnorm(n = k, mean = 5, sd = 1), 1)
## obs x k matrix; column j holds the draws for group j ~ N(mu[j], sig2).
## The draws are independent, so rnorm() is enough (same distribution as
## MASS::mvrnorm(obs, mu, sig2 * diag(k))) and, unlike mvrnorm(), it still
## returns a matrix when obs == 1.
data <- matrix(rnorm(obs * k, mean = rep(mu, each = obs), sd = sqrt(sig2)),
               nrow = obs, ncol = k)

## Group indicator (double, as produced by seq(1, k, 1))
grp <- seq(1, k, by = 1)
true.mean <- cbind(grp, mu)

## Observed group means, grand mean, and spread of the group means
grp.mean <- cbind(grp, colMeans(data))
grand <- mean(data)
ss <- sum((grp.mean[, 2] - grand)^2)

## Put in data.frame (wide format): one row per group, the observations in
## columns V1..V<obs>, followed by grp, gmean (group mean) and grand.
wdat <- as.data.frame(t(data))
names(wdat) <- paste0("V", seq_len(obs))
wdat$grp <- grp
wdat$gmean <- grp.mean[, 2]
wdat$grand <- grand

## JS estimates (Efron-Morris extension: shrink toward the grand mean).
## sig2 / obs is the sampling variance of each group mean.
b <- (k - 3) * (sig2 / obs) / ss
## Positive-part rule: when b > 1 the plain estimator would overshoot the grand
## mean, and when ss == 0 it would give NaN; clamping at 0 avoids both.
shrink <- max(0, 1 - b)
jsest <- grand + shrink * (grp.mean[, 2] - grand)
js.mean <- cbind(grp, jsest)

## Observed squared-error loss of each estimator against the true means
grp.ls <- (grp.mean[, 2] - true.mean[, 2])^2
js.ls <- (js.mean[, 2] - true.mean[, 2])^2
loss <- cbind(grp, grp.ls, js.ls, c(grp.ls > js.ls))
grp.tl <- round(sum(loss[, 2]), 2)   # total loss, raw group means
js.tl <- round(sum(loss[, 3]), 2)    # total loss, James-Stein
js.beat <- sum(loss[, 4])            # number of groups where JS is closer
## Reshape to long format: one row per (group, observation), with the value in
## `y` and the observation index in `obsn`; then order rows by group.
long <- reshape(wdat,
                direction = "long",
                varying = list(names(wdat[, seq_len(obs)])),
                v.names = "y",
                idvar = "grp",
                timevar = "obsn")
ldat <- long[order(long$grp), ]
####
## Attach the James-Stein estimate to every row, then stack the two estimators
## (group mean `gmean`, James-Stein `jsest`) into name/value columns.
df <- ldat %>%
  left_join(as_tibble(js.mean), by = "grp") %>%
  pivot_longer(cols = c("gmean", "jsest"))
## One row per group per estimator for the animated points. `frame` is the
## animation state (1 = group mean, 2 = James-Stein) and is derived from the
## raw estimator code before that code is swapped for a display label.
df_small <- df %>%
  distinct(grp, name, .keep_all = TRUE) %>%
  mutate(frame = ifelse(name == "gmean", 1, 2),
         name = recode(name,
                       `gmean` = "Group Mean",
                       `jsest` = "James-Stein"))
## Animated shrinkage plot: raw observations (static jitter), one point per
## group that moves from the group mean (state 1) to the James-Stein estimate
## (state 2), and a dashed reference line at the grand mean.
anim <-
  ggplot() +
  ## Raw data come from `ldat` (one row per observation). `df` has one row per
  ## observation *per estimator* after pivot_longer(), so plotting it drew every
  ## observation twice. This layer has no `frame` column, so it is not split
  ## across the animation states.
  geom_jitter(data = ldat,
              mapping = aes(x = grp, y = y),
              width = 0.2, size = 0.75, alpha = 0.5) +
  ## `group = grp` lets gganimate match each group's point between the states.
  geom_point(data = df_small,
             mapping = aes(x = grp,
                           y = value,
                           fill = name,
                           group = grp),
             size = 2,
             shape = 21,
             color = "black") +
  ## Single static line at the grand mean (mapping it from df_small stacked one
  ## identical line per row).
  geom_hline(yintercept = grand, linetype = 2) +
  labs(title = "Shrinkage: James-Stein Estimator",
       x = "Group",
       y = "Value") +
  ## NOTE(review): expects the "Source Sans Pro" font to be installed.
  cowplot::theme_cowplot(font_family = "Source Sans Pro",
                         font_size = 10) +
  theme(legend.title = element_blank()) +
  transition_states(frame) +
  ease_aes('quartic-in-out')
#shadow_wake(wake_length = 0.4)
## Render the animation with the ffmpeg renderer (4 s at 60 fps, 5 x 4 in at
## 300 dpi). Despite its name, `gif` holds the rendered video, not a GIF.
## NOTE(review): ffmpeg_renderer() presumably needs an ffmpeg binary available.
gif <- animate(
  anim,
  renderer = ffmpeg_renderer(),
  duration = 4, fps = 60, detail = 20,
  width = 5, height = 4, units = "in", res = 300
)
## Write the rendered video to disk as an mp4.
anim_save(filename = "js_anim.mp4", animation = gif)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment