-
-
Save WardBrian/ad6c22aa0ea32a92906a6cdc18720eca to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Set-up: compile the C log density into a shared library and load ``g``.
# All paths (test.c, ./libnutpietest.so) are relative to the current
# working directory.
import ctypes
import subprocess

import numpy as np

# GCC documents that objects linked with -shared should be compiled with
# -fpic/-fPIC. check=True aborts the script if compilation fails, instead of
# silently going on to load a missing (or stale, from an earlier run)
# libnutpietest.so below.
subprocess.run(
    ["gcc", "-shared", "-fPIC", "-o", "libnutpietest.so", "test.c"],
    check=True,
)

lib = ctypes.CDLL("./libnutpietest.so")
g = lib.g
# g takes a pointer to float64 data; ndpointer lets ctypes pass a numpy
# array's buffer directly and raises TypeError for arrays whose dtype, rank
# or memory layout do not match.
g.argtypes = [
    np.ctypeslib.ndpointer(dtype=np.float64, ndim=1, flags="C_CONTIGUOUS")
]
g.restype = ctypes.c_double
# Build the nutpie callbacks that wrap the C log density:
def make_logp_func():
    """Return the log-density callback used by nutpie.

    The returned function accepts the parameter vector and returns the pair
    ``(logp, grad_logp)``. Any extra keyword arguments (e.g. "shared data"
    for closures) are accepted and ignored.
    See https://github.com/pymc-devs/nutpie/blob/cc34b5810a14bbb4115f02d360c35967092989ba/python/nutpie/compile_pymc.py#L320
    """

    def logp(x, **unused):
        print("hi from python")
        # Standard normal: the log density value is computed in C by ``g``,
        # and the gradient of -x**2 / 2 is simply -x.
        log_density = float(g(x))
        gradient = -x
        return log_density, gradient

    return logp
def make_expand_func(*unused):
    """Return the callback that expands a draw into named outputs.

    Positional arguments passed to this factory are ignored. The returned
    callback ignores keyword arguments and publishes the parameter vector
    itself under the name "y". This hook is presumably also where generated
    quantities would be produced (unconfirmed).
    """

    def expand(x, **unused):
        outputs = {"y": x}
        return outputs

    return expand
import nutpie
from nutpie.compiled_pyfunc import from_pyfunc

# from_pyfunc positional arguments, in order: the parameter dimension (dims),
# the logp factory, the expand factory, then the dtypes, shapes and names of
# the expanded outputs. It accepts further optional arguments that this
# example does not use.
n_dims = 1
output_dtypes = [np.float64]
output_shapes = [(1,)]
output_names = ["y"]
model = from_pyfunc(
    n_dims,
    make_logp_func,
    make_expand_func,
    output_dtypes,
    output_shapes,
    output_names,
)

# Sample, then summarise the posterior draws of the "y" output.
fit = nutpie.sample(model)
y_draws = fit.posterior.y
print(y_draws.mean())
print(y_draws.std())
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <stdio.h> | |
/* Unnormalized standard-normal log density, -y^2 / 2, of y[0].
   Only the first element is read; any further elements are ignored, and the
   data is never modified (hence const). Prints a trace line on every call so
   the caller can see that the C code is being invoked. */
double g(const double *y)
{
    printf("hi from C\n");
    return -0.5 * y[0] * y[0];
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment