Skip to content

Instantly share code, notes, and snippets.

Created December 23, 2018 12:05
Show Gist options
  • Save marta-sd/bd359e5047e7bc1abb8ba5bb65799e35 to your computer and use it in GitHub Desktop.
Save marta-sd/bd359e5047e7bc1abb8ba5bb65799e35 to your computer and use it in GitHub Desktop.
# before running this script clone Pafnucy's repository and create the environment:
# $ git clone https://gitlab.com/cheminfIBB/pafnucy
# $ cd pafnucy
# $ conda env create -f environment_gpu.yml
import numpy as np
import h5py
import tensorflow as tf

from tfbio.data import make_grid
# load Pafnucy
# Import the saved graph definition into a dedicated tf.Graph so that it does
# not mix with the default graph (a second graph is created further below for
# the re-trained model).
graph = tf.Graph()
with graph.as_default():
    saver = tf.train.import_meta_graph('results/batch5-2017-06-05T07:58:47-best.meta')

# get placeholders for input, prediction and target
# The tensors are looked up by name on `graph` explicitly, so these calls do
# not need to run inside the `as_default()` context.
x = graph.get_tensor_by_name('input/structure:0')
y = graph.get_tensor_by_name('output/prediction:0')
t = graph.get_tensor_by_name('input/affinity:0')
keep_prob = graph.get_tensor_by_name('fully_connected/keep_prob:0')
# evaluating this tensor is what drives a training step in the re-training loop
train = graph.get_tensor_by_name('training/train:0')
# load some data
# x_ collects one grid per complex, y_ the matching target value; both are
# converted to arrays after the loop so they can be fed to the graph.
x_ = []
y_ = []
with h5py.File('tests/data/dataset/test_set.hdf', 'r') as f:
    for name in ['1e66', '5c28']:
        dataset = f[name]
        # first three columns are passed to make_grid as coordinates,
        # the remaining columns as per-row features
        coords = dataset[:, :3]
        features = dataset[:, 3:]
        grid = make_grid(coords, features)
        x_.append(grid)
        # NOTE(review): the target is assumed to be stored as the 'affinity'
        # attribute of each dataset (Pafnucy's HDF layout) -- confirm against
        # the file before relying on it.
        y_.append(dataset.attrs['affinity'])
x_ = np.vstack(x_)
# one target per row, as a column vector
y_ = np.reshape(y_, (-1, 1))
print('target values:', y_)
# re-train Pafnucy
with tf.Session(graph=graph) as session:
    # start from the weights of the checkpoint whose meta graph was imported
    saver.restore(session, 'results/batch5-2017-06-05T07:58:47-best')

    # keep_prob is fed as 1.0 for every run in this script
    # (presumably this disables dropout -- confirm against the network definition)
    print('predictions before training:',
          session.run(y, feed_dict={x: x_, keep_prob: 1.0}))

    # ten training runs, each on the full set of loaded examples
    for _ in range(10):
        session.run(train, feed_dict={x: x_, t: y_, keep_prob: 1.0})

    print('predictions after training:',
          session.run(y, feed_dict={x: x_, keep_prob: 1.0}))

    # write the checkpoint that is imported again below as
    # 'pafnucy_retrained.meta' / 'pafnucy_retrained'
    saver.save(session, 'pafnucy_retrained')
# load and use the new model
# A fresh graph is used so the re-trained model is loaded independently of the
# graph that was used for training above.
new_graph = tf.Graph()
with new_graph.as_default():
    saver = tf.train.import_meta_graph('pafnucy_retrained.meta')

# only the input, the prediction and keep_prob are needed for inference
x = new_graph.get_tensor_by_name('input/structure:0')
y = new_graph.get_tensor_by_name('output/prediction:0')
keep_prob = new_graph.get_tensor_by_name('fully_connected/keep_prob:0')

with tf.Session(graph=new_graph) as session:
    saver.restore(session, 'pafnucy_retrained')
    print('predictions with loaded model:',
          session.run(y, feed_dict={x: x_, keep_prob: 1.0}))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment