# General process would be:
# 1. Simulate a wishart draw x with rWish()
# 2. Transform x to the equivalent free_state, using the bijector but running it in reverse
# 3. Run the log_prob() function on free_state
# 4. Run dWish(..., log = TRUE) on x, and compare with result of step 3
devtools::load_all(".")
#> ℹ Loading greta
#> ℹ Initialising python and checking dependencies, this may take a moment.
#>
#> ✔ Initialising python and checking dependencies ... done!
# 1. Simulate a wishart draw x with rWish() ----------------------------------
set.seed(2024-10-18-1502)
sigma <- matrix(
data = c(1.2, 0.7, 0.7, 2.3),
nrow = 2,
ncol = 2
)
df <- 4
x <- rWishart(10, df, sigma)[1, 2, ]
## 2. Transform x to the equivalent free_state, using the bijector but
## running it in reverse -----------------------------------------------------
a_bijector <- tf_covariance_cholesky_bijector()
# I think I need to convert it into the same dimensions/shape as this
# to run the inverse on it?
a_bijector$forward(fl(x))
#> tf.Tensor(
#> [[ 0.31174 1.45087241 0.67074354 -1.1528149 ]
#> [ 0. 1.50973432 7.29376499 4.66827418]
#> [ 0. 0. 9.86134217 -3.79132253]
#> [ 0. 0. 0. 0.48281132]], shape=(4, 4), dtype=float64)
# from https://stackoverflow.com/questions/56387581/converting-a-vector-in-r-into-a-lower-upper-triangular-matrix-in-specific-order
square_x <- matrix(0, nrow = 4, ncol = 4)
# Upper triangle
square_x[upper.tri(square_x, diag = TRUE)] <- x
square_x
#> [,1] [,2] [,3] [,4]
#> [1,] -0.4770558 -3.791323 -1.152815 7.2937650
#> [2,] 0.0000000 4.668274 -1.005707 0.6707435
#> [3,] 0.0000000 0.000000 9.861280 1.4508724
#> [4,] 0.0000000 0.000000 0.000000 1.2600175
free_state <- a_bijector$inverse(fl(square_x))
free_state
#> tf.Tensor(
#> [ 0.92641839 1.45087241 0.67074354 7.29376499 nan 9.86121786
#> -1.00570714 -1.1528149 -3.79132253 4.65883128], shape=(10), dtype=float64)
# 3. Run the log_prob() function on free_state -------------------------------
x_g <- wishart(df, sigma)[1, 2]
m_g <- model(x_g)
new_log_prob <- m_g$dag$generate_log_prob_function()
m_g$dag$define_tf_log_prob_function()
# prob_input <- matrix(rnorm(12), 4, 3)
log_probs <- new_log_prob(free_state)
#> {{function_node
#> __wrapped__Cholesky_device_/job:localhost/replica:0/task:0/device:CPU:0}} Input
#> matrix must be square. [Op:Cholesky] name:
# fails because input isn't square? Maybe convert back to square?
free_state_r <- as.numeric(free_state)
free_state_mat <- matrix(0, nrow = 4, ncol = 4)
free_state_mat[upper.tri(mat, diag = TRUE)] <- free_state_r
#> Error in eval(expr, envir, enclos): object 'mat' not found
free_state_mat
#> [,1] [,2] [,3] [,4]
#> [1,] 0 0 0 0
#> [2,] 0 0 0 0
#> [3,] 0 0 0 0
#> [4,] 0 0 0 0
fl(free_state_mat)
#> tf.Tensor(
#> [[0. 0. 0. 0.]
#> [0. 0. 0. 0.]
#> [0. 0. 0. 0.]
#> [0. 0. 0. 0.]], shape=(4, 4), dtype=float64)
new_log_prob(fl(free_state_mat))
#> Input right-most shape (4) does not correspond to a triangular matrix.
new_log_prob(free_state_mat)
#> Input right-most shape (4) does not correspond to a triangular matrix.
# not sure what to do here about this
# 4. Run dWish(..., log = TRUE) on x, and compare with result of step 3 ------
dwishart <- function(x, df, Sigma, log = FALSE) { # nolint
ans <- MCMCpack::dwish(W = x, v = df, S = Sigma)
if (log) {
ans <- log(ans)
}
ans
}
x
#> [1] -0.4770558 -3.7913225 4.6682742 -1.1528149 -1.0057071 9.8612800
#> [7] 7.2937650 0.6707435 1.4508724 1.2600175
dwishart(x, df = df, Sigma = sigma, log = TRUE)
#> Error in MCMCpack::dwish(W = x, v = df, S = Sigma): W not square in dwish()
# input must be square?
square_x
#> [,1] [,2] [,3] [,4]
#> [1,] -0.4770558 -3.791323 -1.152815 7.2937650
#> [2,] 0.0000000 4.668274 -1.005707 0.6707435
#> [3,] 0.0000000 0.000000 9.861280 1.4508724
#> [4,] 0.0000000 0.000000 0.000000 1.2600175
# input isn't right dimension...there are negative numbers in square x
dwishart(square_x, df = df, Sigma = sigma, log = TRUE)
#> Error in MCMCpack::dwish(W = x, v = df, S = Sigma): W and X of different dimensionality in dwish()
Created on 2024-10-18 with reprex v2.1.1
Session info
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#> setting value
#> version R version 4.4.1 Patched (2024-07-08 r86915)
#> os macOS Sonoma 14.5
#> system aarch64, darwin20
#> ui X11
#> language (EN)
#> collate en_US.UTF-8
#> ctype en_US.UTF-8
#> tz Australia/Hobart
#> date 2024-10-18
#> pandoc 3.2 @ /Applications/RStudio.app/Contents/Resources/app/quarto/bin/tools/aarch64/ (via rmarkdown)
#>
#> ─ Packages ───────────────────────────────────────────────────────────────────
#> ! package * version date (UTC) lib source
#> abind 1.4-5 2016-07-21 [1] CRAN (R 4.4.0)
#> backports 1.5.0 2024-05-23 [1] CRAN (R 4.4.0)
#> base64enc 0.1-3 2015-07-28 [1] CRAN (R 4.4.0)
#> brio 1.1.5 2024-04-24 [1] CRAN (R 4.4.0)
#> cachem 1.1.0 2024-05-16 [1] CRAN (R 4.4.0)
#> callr 3.7.6 2024-03-25 [1] CRAN (R 4.4.0)
#> cli 3.6.3 2024-06-21 [1] CRAN (R 4.4.0)
#> coda 0.19-4.1 2024-01-31 [1] CRAN (R 4.4.0)
#> codetools 0.2-20 2024-03-31 [2] CRAN (R 4.4.1)
#> crayon 1.5.3 2024-06-20 [1] CRAN (R 4.4.0)
#> desc 1.4.3 2023-12-10 [1] CRAN (R 4.4.0)
#> devtools 2.4.5 2022-10-11 [1] CRAN (R 4.4.0)
#> digest 0.6.36 2024-06-23 [1] CRAN (R 4.4.0)
#> ellipsis 0.3.2 2021-04-29 [1] CRAN (R 4.4.0)
#> evaluate 0.24.0 2024-06-10 [1] CRAN (R 4.4.0)
#> fansi 1.0.6 2023-12-08 [1] CRAN (R 4.4.0)
#> fastmap 1.2.0 2024-05-15 [1] CRAN (R 4.4.0)
#> fs 1.6.4.9000 2024-06-26 [1] Github (r-lib/fs@714990b)
#> future 1.34.0 2024-07-29 [1] CRAN (R 4.4.0)
#> globals 0.16.3 2024-03-08 [1] CRAN (R 4.4.0)
#> glue 1.7.0 2024-01-09 [1] CRAN (R 4.4.0)
#> P greta * 0.5.0 2024-10-16 [?] load_all()
#> hms 1.1.3 2023-03-21 [1] CRAN (R 4.4.0)
#> htmltools 0.5.8.1 2024-04-04 [1] CRAN (R 4.4.0)
#> htmlwidgets 1.6.4 2023-12-06 [1] CRAN (R 4.4.0)
#> httpuv 1.6.15 2024-03-26 [1] CRAN (R 4.4.0)
#> jsonlite 1.8.8 2023-12-04 [1] CRAN (R 4.4.0)
#> knitr 1.48 2024-07-07 [1] CRAN (R 4.4.0)
#> later 1.3.2 2023-12-06 [1] CRAN (R 4.4.0)
#> lattice 0.22-6 2024-03-20 [2] CRAN (R 4.4.1)
#> lifecycle 1.0.4 2023-11-07 [1] CRAN (R 4.4.0)
#> listenv 0.9.1 2024-01-29 [1] CRAN (R 4.4.0)
#> magrittr 2.0.3 2022-03-30 [1] CRAN (R 4.4.0)
#> MASS 7.3-61 2024-06-13 [1] CRAN (R 4.4.0)
#> Matrix 1.7-0 2024-04-26 [2] CRAN (R 4.4.1)
#> MatrixModels 0.5-3 2023-11-06 [1] CRAN (R 4.4.0)
#> mcmc 0.9-8 2023-11-16 [1] CRAN (R 4.4.0)
#> MCMCpack 1.7-0 2024-01-18 [1] CRAN (R 4.4.0)
#> memoise 2.0.1.9000 2024-08-14 [1] Github (hadley/memoise@40db995)
#> mime 0.12 2021-09-28 [1] CRAN (R 4.4.0)
#> miniUI 0.1.1.1 2018-05-18 [1] CRAN (R 4.4.0)
#> parallelly 1.38.0 2024-07-27 [1] CRAN (R 4.4.0)
#> pillar 1.9.0 2023-03-22 [1] CRAN (R 4.4.0)
#> pkgbuild 1.4.4 2024-03-17 [1] CRAN (R 4.4.0)
#> pkgconfig 2.0.3 2019-09-22 [1] CRAN (R 4.4.0)
#> pkgload 1.4.0 2024-06-28 [1] CRAN (R 4.4.0)
#> png 0.1-8 2022-11-29 [1] CRAN (R 4.4.0)
#> prettyunits 1.2.0 2023-09-24 [1] CRAN (R 4.4.0)
#> processx 3.8.4 2024-03-16 [1] CRAN (R 4.4.0)
#> profvis 0.3.8 2023-05-02 [1] CRAN (R 4.4.0)
#> progress 1.2.3 2023-12-06 [1] CRAN (R 4.4.0)
#> promises 1.3.0 2024-04-05 [1] CRAN (R 4.4.0)
#> ps 1.8.0 2024-09-12 [1] CRAN (R 4.4.1)
#> purrr 1.0.2 2023-08-10 [1] CRAN (R 4.4.0)
#> quantreg 5.98 2024-05-26 [1] CRAN (R 4.4.0)
#> R6 2.5.1 2021-08-19 [1] CRAN (R 4.4.0)
#> Rcpp 1.0.13 2024-07-17 [1] CRAN (R 4.4.0)
#> remotes 2.5.0 2024-03-17 [1] CRAN (R 4.4.0)
#> reprex 2.1.1 2024-07-06 [1] CRAN (R 4.4.0)
#> reticulate 1.38.0 2024-06-19 [1] CRAN (R 4.4.0)
#> rlang 1.1.4 2024-06-04 [1] CRAN (R 4.4.0)
#> rmarkdown 2.27 2024-05-17 [1] CRAN (R 4.4.0)
#> rprojroot 2.0.4 2023-11-05 [1] CRAN (R 4.4.0)
#> rstudioapi 0.16.0 2024-03-24 [1] CRAN (R 4.4.0)
#> sessioninfo 1.2.2 2021-12-06 [1] CRAN (R 4.4.0)
#> shiny 1.9.1 2024-08-01 [1] CRAN (R 4.4.0)
#> SparseM 1.84-2 2024-07-17 [1] CRAN (R 4.4.0)
#> stringi 1.8.4 2024-05-06 [1] CRAN (R 4.4.0)
#> stringr 1.5.1 2023-11-14 [1] CRAN (R 4.4.0)
#> survival 3.7-0 2024-06-05 [2] CRAN (R 4.4.1)
#> tensorflow 2.16.0 2024-04-15 [1] CRAN (R 4.4.0)
#> testthat * 3.2.1.1 2024-04-14 [1] CRAN (R 4.4.0)
#> tfautograph 0.3.2 2021-09-17 [1] CRAN (R 4.4.0)
#> tfruns 1.5.3 2024-04-19 [1] CRAN (R 4.4.0)
#> urlchecker 1.0.1.9000 2024-08-27 [1] Github (r-lib/urlchecker@ac38ea4)
#> usethis 3.0.0 2024-07-29 [1] CRAN (R 4.4.0)
#> utf8 1.2.4 2023-10-22 [1] CRAN (R 4.4.0)
#> vctrs 0.6.5 2023-12-01 [1] CRAN (R 4.4.0)
#> whisker 0.4.1 2022-12-05 [1] CRAN (R 4.4.0)
#> withr 3.0.1 2024-07-31 [1] CRAN (R 4.4.0)
#> xfun 0.46 2024-07-18 [1] CRAN (R 4.4.0)
#> xtable 1.8-4 2019-04-21 [1] CRAN (R 4.4.0)
#> yaml 2.3.10 2024-07-26 [1] CRAN (R 4.4.0)
#> yesno 0.1.3 2024-07-26 [1] CRAN (R 4.4.0)
#>
#> [1] /Users/nick/Library/R/arm64/4.4/library
#> [2] /Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/library
#>
#> P ── Loaded and on-disk path mismatch.
#>
#> ─ Python configuration ───────────────────────────────────────────────────────
#> python: /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/bin/python
#> libpython: /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/lib/libpython3.10.dylib
#> pythonhome: /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2:/Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2
#> version: 3.10.14 | packaged by conda-forge | (main, Mar 20 2024, 12:51:49) [Clang 16.0.6 ]
#> numpy: /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/lib/python3.10/site-packages/numpy
#> numpy_version: 1.26.4
#> tensorflow: /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/lib/python3.10/site-packages/tensorflow
#>
#> NOTE: Python version was forced by use_python() function
#>
#> ──────────────────────────────────────────────────────────────────────────────