@yk-tanigawa
Last active April 3, 2020 18:26
Multinomial model in brms

Yosuke Tanigawa

2020/4/2

We are trying to fit a multinomial model using brms, and we are not able to get a stable fit. We would love to get some advice.

Specifically, when we run 12 chains, only one of them contains samples. When we lower the minimum sample-size threshold (min_n) to include more countries (columns), none of the 12 chains contains samples.

Our function call has the following form. We used LKJ(2) as the prior for the correlation matrix.

fit <- brm(
    response | trials(n) ~ 1 + (1 | ID | obs),
    family=multinomial(),
    data=A,
    control = brm_control,
    prior = c(
        set_prior("lkj(2)", class="cor")
    ),    
    chains=12,
    iter = brm_iter,
    cores=6
)
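For readers unfamiliar with the `response | trials(n)` syntax: as far as we understand the brms multinomial interface, `response` is a matrix-valued column of counts (one column per category) and `n` holds the row totals. The following base-R sketch builds such a data frame from made-up counts. The names `toy_counts` and `A_toy` and all numbers are hypothetical and are not taken from our data.

```r
# Toy count matrix: 3 observations (rows) x 4 categories (columns).
# All values are invented for illustration.
toy_counts <- matrix(
    c(12, 30,  5,  3,
       8, 22,  9,  1,
      15, 40,  0,  5),
    nrow = 3, byrow = TRUE,
    dimnames = list(NULL, c("cat1", "cat2", "cat3", "cat4"))
)

A_toy <- data.frame(obs = seq_len(nrow(toy_counts)))
A_toy$n <- rowSums(toy_counts)   # number of trials per row
A_toy$response <- toy_counts     # matrix column, as built in prep_brm_input()

# Sanity checks one might run before passing the data to brm():
stopifnot(is.matrix(A_toy$response))
stopifnot(all(rowSums(A_toy$response) == A_toy$n))
stopifnot(all(A_toy$n > 0))
```

The same three checks can be applied to the real input `A` to rule out rows whose counts do not add up to `n`.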

We observed warning messages of the following form:

Chain @: Exception: validate transformed params: y is not positive definite.  (in 'modele@@@' at line 493)
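The exception says that a matrix-valued transformed parameter failed Stan's positive-definiteness check. As a minimal illustration of what that condition means (not a diagnosis of our model), a symmetric matrix is positive definite when all of its eigenvalues are strictly positive. The helper `is_positive_definite` and the two example matrices below are hypothetical.

```r
# Returns TRUE if `m` is symmetric and all eigenvalues are clearly positive.
# `tol` is a relative tolerance that guards against round-off error.
is_positive_definite <- function(m, tol = 1e-8) {
    if (!isSymmetric(m)) return(FALSE)
    ev <- eigen(m, symmetric = TRUE, only.values = TRUE)$values
    all(ev > tol * max(abs(ev)))
}

# A valid 2 x 2 correlation matrix (correlation 0.5).
ok <- matrix(c(1, 0.5, 0.5, 1), nrow = 2)

# A degenerate matrix (correlation exactly 1): singular, so not positive definite.
bad <- matrix(c(1, 1, 1, 1), nrow = 2)

is_positive_definite(ok)
is_positive_definite(bad)
```

A correlation matrix whose entries approach plus or minus 1 sits close to this degenerate boundary, and numerical round-off can push it over.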

We used the latest version (2020/3/31 59a47b97949a4ef77b9398395b94262272bb3d0f) of brms from GitHub.

brms.fit.R is the script used in our analysis.

Our input file: https://github.com/rivas-lab/covid19/blob/master/HLA/HLA.2digits.wide.tables/HLA-A/UKB.HLA.country_of_birth.2digits.wide.tsv

Our output file (that has input matrix A) is uploaded to Google Drive: https://drive.google.com/open?id=1wc0I7bn7Qjwm9JJ_Zz1XzCArrPkrKF9G

suppressPackageStartupMessages(
    devtools::load_all('@@@/paul-buerkner/brms')
)
suppressPackageStartupMessages(library(tidyverse))
suppressPackageStartupMessages(library(data.table))

############################
# Input/output files and run-time parameters
in_freq_f <- 'HLA-A/UKB.HLA.PCA_pop.2digits.wide.tsv'
out_rdata_f <- 'HLA-A/UKB.HLA.PCA_pop.2digits.wide.brms.t.minN.700.fit.1.RData'
min_n <- 700
min_allele_n <- 20
brm_control <- list(adapt_delta = 0.9, max_treedepth=10)
brm_iter <- 4000

############################
# helper functions
transpose_df <- function(df, samples, alleles){
    # Transpose the allele-count columns so that rows are alleles and
    # columns are sample groups; the allele names are stored in the
    # 'sample_group' column of the result.
    t_df <- df %>%
        select(alleles) %>%
        as.matrix() %>%
        t() %>%
        as.data.frame()
    colnames(t_df)<-samples
    t_df <- t_df %>%
        rownames_to_column('sample_group')
    t_df
}

prep_brm_input <- function(df, samples, alleles, alleles_prefix='n_'){
    # Prepare an input data frame for brms::brm()
    # We assume the input df has the following columns:
    # - sample_group whose entries are specified in samples
    # - n_<allele> where allele are elements of alleles
    sub_df <- df %>%
        select(sample_group, paste0(alleles_prefix, alleles)) %>%
        filter(sample_group %in% samples) %>%
        mutate(sample_group = as.character(sample_group))
    # Add the row total (n) and a row index (obs)
    A <- sub_df %>%
        left_join(
            sub_df %>%
                gather(allele, n, -sample_group) %>%
                group_by(sample_group) %>%
                summarise(n = sum(n)) %>%
                ungroup(),
            by='sample_group'
        ) %>%
        mutate(
            obs = 1:n()
        )
    # Attach the count matrix as a matrix-valued column
    A$response <- sub_df %>%
        select(paste0(alleles_prefix, alleles)) %>%
        as.matrix()
    A
}

############################
# read the input file
df <- data.table::fread(in_freq_f, data.table=F, stringsAsFactors=F)
sample_group <- str_replace(names(df)[1], '#', '')
names(df)[1] <- 'sample_group'
# keep sample groups with more than min_n individuals
samples <- df %>%
    filter(n>min_n) %>%
    pull(sample_group)
# keep alleles with a total count above min_allele_n in the retained groups
alleles <- colnames(df)[2:(ncol(df)-1)]
alleles <- df %>%
    filter(sample_group %in% samples) %>%
    select(sample_group, alleles) %>%
    gather(allele, n, -sample_group) %>%
    group_by(allele) %>%
    summarise(n = sum(n)) %>%
    filter(n > min_allele_n) %>%
    pull(allele)

############################
# prepare a data frame for `brms::brm`
# Note: after transpose_df(), rows are alleles and columns are sample groups,
# so `alleles` and `samples` are passed to prep_brm_input() in swapped order.
A <- df %>%
    filter(sample_group %in% samples ) %>%
    transpose_df(samples, alleles) %>%
    prep_brm_input(alleles, samples, '') %>%
    filter(n != 0)

############################
# brms::brm
fit <- brm(
    response | trials(n) ~ 1 + (1 | ID | obs),
    family=multinomial(),
    data=A,
    control = brm_control,
    prior = c(
        set_prior("lkj(2)", class="cor")
    ),
    chains=12,
    iter = brm_iter,
    cores=6
)

############################
# save results to a file
save(
    file = out_rdata_f,
    fit=fit, A=A, df=df,
    samples=samples, alleles=alleles,
    sample_group=sample_group,
    brm_control=brm_control, brm_iter = brm_iter
)
@yk-tanigawa

And these are the contents of the input data frames.

[Image: Screenshot 2020-04-03 00 05 34]
[Image: Screenshot 2020-04-03 00 05 55]