Skip to content

Instantly share code, notes, and snippets.

@njtierney
Created October 18, 2024 04:04
Show Gist options
  • Save njtierney/0fc60f091bc035ea7d218b6a36672aab to your computer and use it in GitHub Desktop.
Save njtierney/0fc60f091bc035ea7d218b6a36672aab to your computer and use it in GitHub Desktop.
# General process would be:
# 1. Simulate a wishart draw x with rWish()
# 2. Transform x to the equivalent free_state, using the bijector but running it in reverse
# 3. Run the log_prob() function on free_state
# 4. Run dWish(..., log = TRUE) on x, and compare with result of step 3
devtools::load_all(".")
#> ℹ Loading greta
#> ℹ Initialising python and checking dependencies, this may take a moment.
#> 
#> ✔ Initialising python and checking dependencies ... done!

# 1. Simulate a wishart draw x with rWish() ----------------------------------
set.seed(2024-10-18-1502)
sigma <- matrix(
  data = c(1.2, 0.7, 0.7, 2.3),
  nrow = 2,
  ncol = 2
)

df <- 4

x <- rWishart(10, df, sigma)[1, 2, ]
## 2. Transform x to the equivalent free_state, using the bijector but 
## running it in reverse -----------------------------------------------------
a_bijector <- tf_covariance_cholesky_bijector()

# I think I need to convert it into the same dimensions/shape as this 
# to run the inverse on it?
a_bijector$forward(fl(x))
#> tf.Tensor(
#> [[ 0.31174     1.45087241  0.67074354 -1.1528149 ]
#>  [ 0.          1.50973432  7.29376499  4.66827418]
#>  [ 0.          0.          9.86134217 -3.79132253]
#>  [ 0.          0.          0.          0.48281132]], shape=(4, 4), dtype=float64)

# from https://stackoverflow.com/questions/56387581/converting-a-vector-in-r-into-a-lower-upper-triangular-matrix-in-specific-order
square_x <- matrix(0, nrow = 4, ncol = 4)
# Upper triangle
square_x[upper.tri(square_x, diag = TRUE)] <- x
square_x
#>            [,1]      [,2]      [,3]      [,4]
#> [1,] -0.4770558 -3.791323 -1.152815 7.2937650
#> [2,]  0.0000000  4.668274 -1.005707 0.6707435
#> [3,]  0.0000000  0.000000  9.861280 1.4508724
#> [4,]  0.0000000  0.000000  0.000000 1.2600175

free_state <- a_bijector$inverse(fl(square_x))
free_state
#> tf.Tensor(
#> [ 0.92641839  1.45087241  0.67074354  7.29376499         nan  9.86121786
#>  -1.00570714 -1.1528149  -3.79132253  4.65883128], shape=(10), dtype=float64)

# 3. Run the log_prob() function on free_state -------------------------------
x_g <- wishart(df, sigma)[1, 2]
m_g <- model(x_g)
new_log_prob <- m_g$dag$generate_log_prob_function()
m_g$dag$define_tf_log_prob_function()
# prob_input <- matrix(rnorm(12), 4, 3)
log_probs <- new_log_prob(free_state)
#> {{function_node
#> __wrapped__Cholesky_device_/job:localhost/replica:0/task:0/device:CPU:0}} Input
#> matrix must be square. [Op:Cholesky] name:

# fails because input isn't square? Maybe convert back to square?
free_state_r <- as.numeric(free_state)
free_state_mat <- matrix(0, nrow = 4, ncol = 4)
free_state_mat[upper.tri(mat, diag = TRUE)] <- free_state_r
#> Error in eval(expr, envir, enclos): object 'mat' not found
free_state_mat
#>      [,1] [,2] [,3] [,4]
#> [1,]    0    0    0    0
#> [2,]    0    0    0    0
#> [3,]    0    0    0    0
#> [4,]    0    0    0    0

fl(free_state_mat)
#> tf.Tensor(
#> [[0. 0. 0. 0.]
#>  [0. 0. 0. 0.]
#>  [0. 0. 0. 0.]
#>  [0. 0. 0. 0.]], shape=(4, 4), dtype=float64)
new_log_prob(fl(free_state_mat))
#> Input right-most shape (4) does not correspond to a triangular matrix.
new_log_prob(free_state_mat)
#> Input right-most shape (4) does not correspond to a triangular matrix.

# not sure what to do here about this

# 4. Run dWish(..., log = TRUE) on x, and compare with result of step 3 ------
dwishart <- function(x, df, Sigma, log = FALSE) { # nolint
  ans <- MCMCpack::dwish(W = x, v = df, S = Sigma)
  if (log) {
    ans <- log(ans)
  }
  ans
}
x
#>  [1] -0.4770558 -3.7913225  4.6682742 -1.1528149 -1.0057071  9.8612800
#>  [7]  7.2937650  0.6707435  1.4508724  1.2600175
dwishart(x, df = df, Sigma = sigma, log = TRUE)
#> Error in MCMCpack::dwish(W = x, v = df, S = Sigma): W not square in dwish()
# input must be square?
square_x
#>            [,1]      [,2]      [,3]      [,4]
#> [1,] -0.4770558 -3.791323 -1.152815 7.2937650
#> [2,]  0.0000000  4.668274 -1.005707 0.6707435
#> [3,]  0.0000000  0.000000  9.861280 1.4508724
#> [4,]  0.0000000  0.000000  0.000000 1.2600175
# input isn't right dimension...there are negative numbers in square x
dwishart(square_x, df = df, Sigma = sigma, log = TRUE)
#> Error in MCMCpack::dwish(W = x, v = df, S = Sigma): W and X of different dimensionality in dwish()

Created on 2024-10-18 with reprex v2.1.1

Session info
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value
#>  version  R version 4.4.1 Patched (2024-07-08 r86915)
#>  os       macOS Sonoma 14.5
#>  system   aarch64, darwin20
#>  ui       X11
#>  language (EN)
#>  collate  en_US.UTF-8
#>  ctype    en_US.UTF-8
#>  tz       Australia/Hobart
#>  date     2024-10-18
#>  pandoc   3.2 @ /Applications/RStudio.app/Contents/Resources/app/quarto/bin/tools/aarch64/ (via rmarkdown)
#> 
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  ! package      * version    date (UTC) lib source
#>    abind          1.4-5      2016-07-21 [1] CRAN (R 4.4.0)
#>    backports      1.5.0      2024-05-23 [1] CRAN (R 4.4.0)
#>    base64enc      0.1-3      2015-07-28 [1] CRAN (R 4.4.0)
#>    brio           1.1.5      2024-04-24 [1] CRAN (R 4.4.0)
#>    cachem         1.1.0      2024-05-16 [1] CRAN (R 4.4.0)
#>    callr          3.7.6      2024-03-25 [1] CRAN (R 4.4.0)
#>    cli            3.6.3      2024-06-21 [1] CRAN (R 4.4.0)
#>    coda           0.19-4.1   2024-01-31 [1] CRAN (R 4.4.0)
#>    codetools      0.2-20     2024-03-31 [2] CRAN (R 4.4.1)
#>    crayon         1.5.3      2024-06-20 [1] CRAN (R 4.4.0)
#>    desc           1.4.3      2023-12-10 [1] CRAN (R 4.4.0)
#>    devtools       2.4.5      2022-10-11 [1] CRAN (R 4.4.0)
#>    digest         0.6.36     2024-06-23 [1] CRAN (R 4.4.0)
#>    ellipsis       0.3.2      2021-04-29 [1] CRAN (R 4.4.0)
#>    evaluate       0.24.0     2024-06-10 [1] CRAN (R 4.4.0)
#>    fansi          1.0.6      2023-12-08 [1] CRAN (R 4.4.0)
#>    fastmap        1.2.0      2024-05-15 [1] CRAN (R 4.4.0)
#>    fs             1.6.4.9000 2024-06-26 [1] Github (r-lib/fs@714990b)
#>    future         1.34.0     2024-07-29 [1] CRAN (R 4.4.0)
#>    globals        0.16.3     2024-03-08 [1] CRAN (R 4.4.0)
#>    glue           1.7.0      2024-01-09 [1] CRAN (R 4.4.0)
#>  P greta        * 0.5.0      2024-10-16 [?] load_all()
#>    hms            1.1.3      2023-03-21 [1] CRAN (R 4.4.0)
#>    htmltools      0.5.8.1    2024-04-04 [1] CRAN (R 4.4.0)
#>    htmlwidgets    1.6.4      2023-12-06 [1] CRAN (R 4.4.0)
#>    httpuv         1.6.15     2024-03-26 [1] CRAN (R 4.4.0)
#>    jsonlite       1.8.8      2023-12-04 [1] CRAN (R 4.4.0)
#>    knitr          1.48       2024-07-07 [1] CRAN (R 4.4.0)
#>    later          1.3.2      2023-12-06 [1] CRAN (R 4.4.0)
#>    lattice        0.22-6     2024-03-20 [2] CRAN (R 4.4.1)
#>    lifecycle      1.0.4      2023-11-07 [1] CRAN (R 4.4.0)
#>    listenv        0.9.1      2024-01-29 [1] CRAN (R 4.4.0)
#>    magrittr       2.0.3      2022-03-30 [1] CRAN (R 4.4.0)
#>    MASS           7.3-61     2024-06-13 [1] CRAN (R 4.4.0)
#>    Matrix         1.7-0      2024-04-26 [2] CRAN (R 4.4.1)
#>    MatrixModels   0.5-3      2023-11-06 [1] CRAN (R 4.4.0)
#>    mcmc           0.9-8      2023-11-16 [1] CRAN (R 4.4.0)
#>    MCMCpack       1.7-0      2024-01-18 [1] CRAN (R 4.4.0)
#>    memoise        2.0.1.9000 2024-08-14 [1] Github (hadley/memoise@40db995)
#>    mime           0.12       2021-09-28 [1] CRAN (R 4.4.0)
#>    miniUI         0.1.1.1    2018-05-18 [1] CRAN (R 4.4.0)
#>    parallelly     1.38.0     2024-07-27 [1] CRAN (R 4.4.0)
#>    pillar         1.9.0      2023-03-22 [1] CRAN (R 4.4.0)
#>    pkgbuild       1.4.4      2024-03-17 [1] CRAN (R 4.4.0)
#>    pkgconfig      2.0.3      2019-09-22 [1] CRAN (R 4.4.0)
#>    pkgload        1.4.0      2024-06-28 [1] CRAN (R 4.4.0)
#>    png            0.1-8      2022-11-29 [1] CRAN (R 4.4.0)
#>    prettyunits    1.2.0      2023-09-24 [1] CRAN (R 4.4.0)
#>    processx       3.8.4      2024-03-16 [1] CRAN (R 4.4.0)
#>    profvis        0.3.8      2023-05-02 [1] CRAN (R 4.4.0)
#>    progress       1.2.3      2023-12-06 [1] CRAN (R 4.4.0)
#>    promises       1.3.0      2024-04-05 [1] CRAN (R 4.4.0)
#>    ps             1.8.0      2024-09-12 [1] CRAN (R 4.4.1)
#>    purrr          1.0.2      2023-08-10 [1] CRAN (R 4.4.0)
#>    quantreg       5.98       2024-05-26 [1] CRAN (R 4.4.0)
#>    R6             2.5.1      2021-08-19 [1] CRAN (R 4.4.0)
#>    Rcpp           1.0.13     2024-07-17 [1] CRAN (R 4.4.0)
#>    remotes        2.5.0      2024-03-17 [1] CRAN (R 4.4.0)
#>    reprex         2.1.1      2024-07-06 [1] CRAN (R 4.4.0)
#>    reticulate     1.38.0     2024-06-19 [1] CRAN (R 4.4.0)
#>    rlang          1.1.4      2024-06-04 [1] CRAN (R 4.4.0)
#>    rmarkdown      2.27       2024-05-17 [1] CRAN (R 4.4.0)
#>    rprojroot      2.0.4      2023-11-05 [1] CRAN (R 4.4.0)
#>    rstudioapi     0.16.0     2024-03-24 [1] CRAN (R 4.4.0)
#>    sessioninfo    1.2.2      2021-12-06 [1] CRAN (R 4.4.0)
#>    shiny          1.9.1      2024-08-01 [1] CRAN (R 4.4.0)
#>    SparseM        1.84-2     2024-07-17 [1] CRAN (R 4.4.0)
#>    stringi        1.8.4      2024-05-06 [1] CRAN (R 4.4.0)
#>    stringr        1.5.1      2023-11-14 [1] CRAN (R 4.4.0)
#>    survival       3.7-0      2024-06-05 [2] CRAN (R 4.4.1)
#>    tensorflow     2.16.0     2024-04-15 [1] CRAN (R 4.4.0)
#>    testthat     * 3.2.1.1    2024-04-14 [1] CRAN (R 4.4.0)
#>    tfautograph    0.3.2      2021-09-17 [1] CRAN (R 4.4.0)
#>    tfruns         1.5.3      2024-04-19 [1] CRAN (R 4.4.0)
#>    urlchecker     1.0.1.9000 2024-08-27 [1] Github (r-lib/urlchecker@ac38ea4)
#>    usethis        3.0.0      2024-07-29 [1] CRAN (R 4.4.0)
#>    utf8           1.2.4      2023-10-22 [1] CRAN (R 4.4.0)
#>    vctrs          0.6.5      2023-12-01 [1] CRAN (R 4.4.0)
#>    whisker        0.4.1      2022-12-05 [1] CRAN (R 4.4.0)
#>    withr          3.0.1      2024-07-31 [1] CRAN (R 4.4.0)
#>    xfun           0.46       2024-07-18 [1] CRAN (R 4.4.0)
#>    xtable         1.8-4      2019-04-21 [1] CRAN (R 4.4.0)
#>    yaml           2.3.10     2024-07-26 [1] CRAN (R 4.4.0)
#>    yesno          0.1.3      2024-07-26 [1] CRAN (R 4.4.0)
#> 
#>  [1] /Users/nick/Library/R/arm64/4.4/library
#>  [2] /Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/library
#> 
#>  P ── Loaded and on-disk path mismatch.
#> 
#> ─ Python configuration ───────────────────────────────────────────────────────
#>  python:         /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/bin/python
#>  libpython:      /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/lib/libpython3.10.dylib
#>  pythonhome:     /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2:/Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2
#>  version:        3.10.14 | packaged by conda-forge | (main, Mar 20 2024, 12:51:49) [Clang 16.0.6 ]
#>  numpy:          /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/lib/python3.10/site-packages/numpy
#>  numpy_version:  1.26.4
#>  tensorflow:     /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/lib/python3.10/site-packages/tensorflow
#>  
#>  NOTE: Python version was forced by use_python() function
#> 
#> ──────────────────────────────────────────────────────────────────────────────
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment