Last active
October 18, 2022 04:58
-
-
Save wj-Mcat/57dc5505d58acb6c8a6a998460da54bc to your computer and use it in GitHub Desktop.
Performance comparison: cached vs. recomputed `position_ids` (Paddle)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#%%
import paddle | |
from paddle import nn | |
import time | |
import pandas as pd | |
from tqdm import tqdm, trange | |
def perf_position_ids_cache(epoch: int = 1000, max_position_embedding: int = 512, hidden_size: int = 768):
    """Time slicing ``position_ids`` out of a precomputed cache.

    Builds ``arange(max_position_embedding)`` once as a
    ``[1, max_position_embedding]`` tensor, then times taking its first
    ``max_position_embedding // 2`` positions ``epoch`` times.

    Args:
        epoch: Number of timed iterations; must be positive.
        max_position_embedding: Length of the cached position range. The
            benchmarked sequence length is half of this value.
        hidden_size: Unused; kept for signature parity with
            ``perf_position_ids``.

    Returns:
        Average wall-clock cost of one iteration, in milliseconds.

    Raises:
        ValueError: If ``epoch`` is not positive.
    """
    if epoch <= 0:
        raise ValueError(f"epoch must be positive, got {epoch}")
    seq_len = max_position_embedding // 2
    # Built once, outside the timed region, so only the per-step slice is measured.
    position_ids_cache = paddle.expand(
        paddle.arange(max_position_embedding, dtype="int64"), shape=[1, -1]
    )
    # NOTE(review): on GPU, paddle may launch kernels asynchronously; without a
    # device synchronize before reading the clock this may time only kernel
    # launches. Confirm on the target device.
    start = time.perf_counter()
    for _ in range(epoch):
        # Slice the sequence axis (axis 1). Axis 0 has size 1, so slicing it
        # (as `cache[:seq_len]` would) returns the whole cache unchanged.
        _ = position_ids_cache[:, :seq_len]
    # Divide by the real iteration count to get ms per iteration.
    return (time.perf_counter() - start) / epoch * 1000
def perf_position_ids(epoch: int = 1000, max_position_embedding: int = 512, hidden_size: int = 768):
    """Time recomputing ``position_ids`` from the input on every step.

    For a ``[4, max_position_embedding // 2]`` input, rebuilds
    ``[0, 1, ..., seq_len - 1]`` per row as ``cumsum(ones) - ones``,
    ``epoch`` times.

    Args:
        epoch: Number of timed iterations; must be positive.
        max_position_embedding: The benchmarked sequence length is half of
            this value.
        hidden_size: Unused; kept for signature parity with
            ``perf_position_ids_cache``.

    Returns:
        Average wall-clock cost of one iteration, in milliseconds.

    Raises:
        ValueError: If ``epoch`` is not positive.
    """
    if epoch <= 0:
        raise ValueError(f"epoch must be positive, got {epoch}")
    seq_len = max_position_embedding // 2
    # Only the shape of input_ids matters: ones_like ignores its values.
    input_ids = paddle.randn([4, seq_len])
    # NOTE(review): on GPU, paddle may launch kernels asynchronously; without a
    # device synchronize before reading the clock this may time only kernel
    # launches. Confirm on the target device.
    start = time.perf_counter()
    for _ in range(epoch):
        ones = paddle.ones_like(input_ids, dtype="int64")
        # cumsum along the sequence axis yields 1..seq_len; subtracting the
        # ones shifts it to 0..seq_len-1.
        cumulative = paddle.cumsum(ones, axis=-1)
        _ = cumulative - ones
    # Divide by the real iteration count to get ms per iteration.
    return (time.perf_counter() - start) / epoch * 1000
#%% | |
def run_function_perf(epochs=range(100, 1000, 100)):
    """Benchmark cached vs. recomputed ``position_ids`` over several epoch counts.

    For each entry of ``epochs``, calls ``perf_position_ids_cache`` and
    ``perf_position_ids`` and records both timings and their ratio. It then
    plots the speed-up against the epoch count.

    Args:
        epochs: Iterable of iteration counts to benchmark; each must be a
            positive int. Defaults to 100, 200, ..., 900 (the original sweep).

    Returns:
        A ``pandas.DataFrame`` with one row per epoch count and columns:
        ``cache`` and ``no_cache`` (the values returned by the two timing
        functions), ``speed up`` (``no_cache / cache``) and ``epochs``.
    """
    series = []
    for i in tqdm(epochs):
        cache_cost = perf_position_ids_cache(epoch=i)
        no_cache_cost = perf_position_ids(epoch=i)
        series.append(
            {
                "cache": cache_cost,
                "no_cache": no_cache_cost,
                "speed up": (no_cache_cost / cache_cost),
                "epochs": i,
            }
        )
    table = pd.DataFrame(series)
    # An empty sweep yields a DataFrame without an "epochs" column; plotting
    # it would raise, so only plot when there is data.
    if not table.empty:
        table.plot(x="epochs", y="speed up")
    # Return the table so script (non-notebook) callers can inspect or save it.
    return table


# NOTE: disabled alternative sweep over max_position_embedding, kept for
# reference; when enabled it writes an Excel file.
# series = []
# for i in trange(128, 1024, 10):
#     series.append(
#         {
#             "cache": perf_position_ids_cache(max_position_embedding=i),
#             "no_cache": perf_position_ids(max_position_embedding=i)
#         }
#     )
# pd.DataFrame(series).to_excel("./sss/perf-max-posi.xlsx", index=False)
# Run the benchmark only when executed directly (this also holds inside an
# interactive `#%%` cell), not when this module is imported.
if __name__ == "__main__":
    run_function_perf()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment