/nutpie-ffi.py Secret
Created
August 14, 2024 16:15
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# set-up: compile test.c into a shared library and load the C function `g` from it.
import subprocess
import ctypes
import numpy as np

# check=True makes a failed compile raise CalledProcessError right here, instead of
# the failure surfacing later as a confusing OSError from ctypes.CDLL (or a stale,
# previously built .so being loaded silently).
# -fPIC: shared objects must be position-independent; some toolchains do not
# default to it and refuse to link with `-shared` alone.
subprocess.run(
    ["gcc", "-shared", "-fPIC", "-o", "libnutpietest.so", "test.c"],
    check=True,
)
lib = ctypes.CDLL("./libnutpietest.so")
g = lib.g
# Declare the C signature `double g(double *y)` so ctypes checks that callers pass
# a 1-D, C-contiguous float64 array and converts the return value to a Python float.
g.argtypes = [np.ctypeslib.ndpointer(dtype=np.float64, ndim=1, flags='C_CONTIGUOUS')]
g.restype = ctypes.c_double
# build nutpie wrapper for my code: | |
def make_logp_func():
    """Return the log-density callback handed to nutpie.

    The returned ``logp(x, **unused)`` takes the position vector ``x`` and
    returns the pair ``(lp, grad_lp)``. Any keyword arguments (shared data,
    e.g. for closures) are accepted and ignored.
    See https://github.com/pymc-devs/nutpie/blob/cc34b5810a14bbb4115f02d360c35967092989ba/python/nutpie/compile_pymc.py#L320
    """

    def logp(x, **unused):
        print("hi from python")
        # Standard normal: the C routine g evaluates -0.5 * x[0]**2 (x must match
        # g's declared argtypes: a 1-D C-contiguous float64 array). The gradient
        # of that density with respect to x is -x.
        log_density = float(g(x))
        gradient = -x
        return log_density, gradient

    return logp
def make_expand_func(*unused):
    """Return a callback that maps a draw to the named values to store.

    Positional arguments are accepted and ignored (NOTE(review): nutpie passes
    some per-chain arguments here, presumably seeds/chain id — confirm). The
    returned ``expand(x, **unused)`` stores the raw position vector under the
    name ``"y"``. This hook appears to be how derived/generated quantities can
    be added to the output.
    """

    def expand(x, **unused):
        expanded = {"y": x}
        return expanded

    return expand
import nutpie
from nutpie.compiled_pyfunc import from_pyfunc

# Positional arguments, in order: dimension of the parameter vector, the logp
# factory, the expand factory, then the dtypes, shapes and names of the expanded
# outputs. from_pyfunc accepts further (keyword) options not used here.
n_dim = 1
out_dtypes = [np.float64]
out_shapes = [(1,)]
out_names = ["y"]
model = from_pyfunc(n_dim, make_logp_func, make_expand_func, out_dtypes, out_shapes, out_names)

# Draw posterior samples and summarise the stored quantity "y".
fit = nutpie.sample(model)
y_draws = fit.posterior.y
print(y_draws.mean())
print(y_draws.std())
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <stdio.h> | |
/* Log density of a standard normal (up to an additive constant), evaluated at
 * the single value y[0]. Prints a trace line on every call so the Python -> C
 * call path is visible. */
double g(double *y)
{
    const double v = y[0];
    printf("hi from C\n");
    return -0.5 * v * v;
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment