Skip to content

Instantly share code, notes, and snippets.

@albertz
Created October 1, 2021 18:56
Show Gist options
  • Save albertz/21e00a500e41eb0c8d27a8519e763f0e to your computer and use it in GitHub Desktop.
from returnn.config import Config
from returnn.tf.engine import Engine
import sys
sys.path.append("tests")
from test_TFNetworkLayer import make_feed_dict
# Input feature size; it appears in both extern_data shapes used by the script,
# and the checkpoint path is written by the first Engine and loaded by the second.
n_in, model_filename = 40, "test-703.model.001"
def make_net_dict():
    """Return the RETURNN network definition used by this script.

    The network has a single layer named ``"output"``. It is a ``"conv"``
    layer reading from ``"data"``, with a 3x3 filter, ``"same"`` padding,
    32 output units, no activation, and a bias term.

    A new dict is built on every call, so the caller may mutate it freely.
    """
    return {
        "output": {
            "class": "conv", "from": "data",
            "filter_size": [3, 3], "padding": "same",
            "n_out": 32, "activation": None, "with_bias": True,
        }
    }
# Reproduction script: build and save a model under one declared input
# shape, then reload that checkpoint under a different declared input
# shape and run the output layer once.
#
# Phase 1: a "train" config whose input "data" has shape (None, n_in, 1).
# Initialize the network and save the model to `model_filename`.
config = Config({
    "extern_data": {"data": {"shape": (None, n_in, 1)}},
    "task": "train",
    "network": make_net_dict(),
})
engine = Engine(config=config)
engine.init_train_from_config()
engine.save_model(model_filename)
# Phase 2: edit the same Config object in place via `config.typed_dict`.
# Swap the last two dims of the declared "data" shape to (None, 1, n_in),
# switch the task to "eval", and point "load" at the checkpoint saved above.
# NOTE(review): presumably the conv layer's input-feature dim changes from 1
# to n_in here, so the saved kernel no longer matches the variable it is
# restored into. This looks like the intended trigger for the problem the
# gist author reported as tensorflow/tensorflow#52223 -- confirm before
# "fixing" the shape mismatch.
config.typed_dict["extern_data"]["data"]["shape"] = (None, 1, n_in)
config.typed_dict["task"] = "eval"
config.typed_dict["load"] = model_filename
# A fresh Engine is created so the network is rebuilt from the edited config.
engine = Engine(config=config)
engine.init_network_from_config()
net = engine.network
out = net.get_layer("output").output
# Evaluate the conv output once. make_feed_dict comes from RETURNN's
# tests/test_TFNetworkLayer.py; presumably it fills each extern_data entry
# with dummy values -- TODO confirm.
engine.tf_session.run(out.placeholder, feed_dict=make_feed_dict(net.extern_data))
@albertz
Copy link
Author

albertz commented Oct 1, 2021

I reported this issue to TensorFlow here: tensorflow/tensorflow#52223

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment