@wj-Mcat
Created October 18, 2022 07:32
Paddle dynamic-to-static conversion example
import paddle
from paddlenlp.transformers import AutoModel

# 1. print the Paddle version
print("Paddle Version: " + paddle.__version__)
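# 2. load the dynamic-graph model and switch it to eval mode
# (without eval(), dropout stays active and the dynamic and static outputs differ)
input_ids = paddle.randint(10, 20, shape=[1, 20], dtype='int64')
model = AutoModel.from_pretrained("albert-base-v1")
model.eval()
dynamic_output = model(input_ids)[0]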
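# 3. convert to a static graph, save it, then load it back
# InputSpec shape [None, None] = [batch_size, seq_len], both left dynamic
inputs = [paddle.static.InputSpec(shape=[None, None], dtype="int64")]
model = paddle.jit.to_static(model, input_spec=inputs)
path = "sss/static_model"
paddle.jit.save(model, path)
model = paddle.jit.load(path)
model.eval()  # the loaded layer is also a paddle.nn.Layer; keep it in eval mode too
static_output = model(input_ids)[0]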
# 4. compare the dynamic output and static output
print(dynamic_output[:, 1:4, 1:4])
print("=======")
print(static_output[:, 1:4, 1:4])
assert paddle.allclose(dynamic_output[:, 1:4, 1:4], static_output[:, 1:4, 1:4], atol=1e-4)
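The final assertion relies on `paddle.allclose`, which treats two values as close when |x - y| <= atol + rtol * |y| (Paddle's defaults are rtol=1e-05, atol=1e-08). The stdlib-only sketch below mirrors that element-wise rule on flat Python lists so the effect of `atol=1e-4` is easy to see. `allclose_lists`, `max_abs_diff` and the sample numbers are hypothetical illustrations, not Paddle APIs or real model outputs, and the helper ignores NaN handling and broadcasting.

```python
def allclose_lists(x, y, rtol=1e-05, atol=1e-08):
    """Hypothetical helper: element-wise |x - y| <= atol + rtol * |y| over flat lists."""
    if len(x) != len(y):
        raise ValueError("inputs must have the same length")
    return all(abs(a - b) <= atol + rtol * abs(b) for a, b in zip(x, y))


def max_abs_diff(x, y):
    """Largest element-wise absolute difference, handy when an allclose check fails."""
    return max(abs(a - b) for a, b in zip(x, y))


# Made-up values standing in for a small slice of the dynamic and static outputs.
dynamic = [0.1234, -0.5678, 1.0000]
static = [0.12345, -0.56781, 1.00002]

print(allclose_lists(dynamic, static, atol=1e-4))  # → True (all gaps are below 1e-4)
print(allclose_lists(dynamic, static))             # → False (default atol is much tighter)
print(max_abs_diff(dynamic, static))               # roughly 5e-05
```

Printing the maximum absolute difference alongside the boolean check makes it easier to tell a tolerance problem (tiny gaps) from a real mismatch such as active dropout (large gaps).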

wj-Mcat commented Oct 18, 2022

Good catch, thanks @gongel: after switching the model to eval mode with `model.eval()`, the dynamic and static models output the same logits.
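The mismatch that `eval()` fixes comes from layers such as dropout, which randomly zero activations in training mode and act as the identity in evaluation mode (`paddle.nn.Dropout` defaults to the `upscale_in_train` mode, i.e. inverted dropout). The plain-Python sketch below imitates that behaviour to show why two training-mode passes, one standing in for the dynamic run and one for the static run, disagree while eval-mode passes match. `dropout` is a hypothetical stand-in rather than Paddle's implementation, and p=0.1 and the hidden values are assumed for illustration.

```python
import random


def dropout(values, p, training, rng):
    """Hypothetical inverted dropout: in training mode, zero each value with
    probability p and rescale survivors by 1 / (1 - p); in eval mode, return
    the values unchanged."""
    if not training or p == 0.0:
        return list(values)
    scale = 1.0 / (1.0 - p)
    return [0.0 if rng.random() < p else v * scale for v in values]


# Assumed non-zero activations, so any difference in the dropout masks shows up.
hidden = [float(i) + 1.0 for i in range(200)]

# Training mode: different random states (dynamic run vs. static run) draw
# different masks, so the outputs disagree.
train_a = dropout(hidden, p=0.1, training=True, rng=random.Random(1))
train_b = dropout(hidden, p=0.1, training=True, rng=random.Random(2))

# Eval mode: dropout is the identity, so both passes return the same values.
eval_a = dropout(hidden, p=0.1, training=False, rng=random.Random(1))
eval_b = dropout(hidden, p=0.1, training=False, rng=random.Random(2))

print(train_a == train_b)  # different masks, so this is (almost surely) False
print(eval_a == eval_b)    # → True
```

This is why calling `eval()` before comparing is the natural fix: it removes the randomness instead of loosening the tolerance.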
