
@junpenglao
Last active November 10, 2020 07:46
theano-jax test drive.ipynb
@GarrettMooney

This Dockerfile worked for me to reproduce this environment:

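```dockerfile
# CUDA 9.1 / Python 3.7 machine-learning base image
FROM tiangolo/python-machine-learning:cuda9.1-python3.7
# PyMC3 from GitHub master; --no-deps skips installing its dependencies
RUN pip install --no-deps git+https://github.com/pymc-devs/pymc3
# Brandon Willard's Theano fork, jax-linker branch; only its requirements.txt is installed here
RUN git clone https://github.com/brandonwillard/Theano.git -b jax-linker && \
    cd Theano && \
    pip install -r requirements.txt
# Nightly TensorFlow (GPU) and TensorFlow Probability, plus JAX and jaxlib
RUN pip install --upgrade pip && pip install --upgrade tf-nightly-gpu tfp-nightly jax jaxlib
RUN pip install xarray
# --ignore-installed overwrites already-installed packages (here arviz and certifi) instead of uninstalling them first
RUN pip install arviz --ignore-installed certifi
RUN pip install dill fastprogress
```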

@junpenglao (author)

Awesome! Thanks, Garrett!

@twiecki commented Sep 27, 2020

@junpenglao It would be interesting to try this with the PyMC3 gradient function rather than the JAX-derived one.
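A minimal sketch of the kind of comparison suggested above, assuming only that JAX is installed: it checks `jax.grad` of a toy Normal log-density against a hand-derived gradient. The model and the names `normal_logp` and `manual_grad` are hypothetical illustrations; the hand-written gradient merely stands in for a symbolically derived one (such as PyMC3/Theano would produce) and does not call PyMC3 or Theano.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy model: i.i.d. Normal(mu, sigma) log-density, parameterized by (mu, log_sigma).
def normal_logp(params, x):
    mu, log_sigma = params
    sigma = jnp.exp(log_sigma)
    z = (x - mu) / sigma
    return jnp.sum(-0.5 * z**2 - log_sigma - 0.5 * jnp.log(2 * jnp.pi))

# Gradient with respect to params, derived automatically by JAX.
jax_grad = jax.grad(normal_logp)

# Hand-derived gradient, standing in for a symbolically derived (e.g. Theano-style) one:
#   d/d mu        = sum(z / sigma)
#   d/d log_sigma = sum(z**2 - 1)
def manual_grad(params, x):
    mu, log_sigma = params
    sigma = jnp.exp(log_sigma)
    z = (x - mu) / sigma
    return jnp.stack([jnp.sum(z / sigma), jnp.sum(z**2 - 1.0)])

x = jnp.array([0.5, -1.2, 2.3, 0.1])
params = jnp.array([0.3, 0.2])
print(jnp.allclose(jax_grad(params, x), manual_grad(params, x), atol=1e-5))
```

In a real comparison, the two gradient functions would typically both be wrapped in `jax.jit` (or compiled on the Theano side) before timing them inside the sampler, since correctness agreement alone says nothing about speed.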
