Created
April 15, 2023 12:08
-
-
Save bogdad/80b18770ae5bc75d8563f00e5fbb47c9 to your computer and use it in GitHub Desktop.
benchmark_threads.txt for mac
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import re
import shlex
import subprocess

import matplotlib.pyplot as plt
# Command template; {threads} is substituted for each benchmark configuration.
cmd = (
    "./build/bin/main "
    "--seed 147369852 "
    "--threads {threads} "
    "--n_predict 64 "
    "--model ./models/7B/ggml-model-q4_0.bin "
    "--top_k 40 "
    "--top_p 0.9 "
    "--temp 0.5 "
    "--repeat_last_n 64 "
    "--repeat_penalty 1.1 "
    "-p \"Write a funny joke:\""
)

# Range of thread counts to loop over (max_threads is included).
min_threads = 4
max_threads = 16
step = 2

# Number of runs averaged for each thread count.
n_runs = 5

# Compiled once: the same pattern is applied to the output of every run.
_TOKEN_TIME_RE = re.compile(r"\s+(\d+\.\d+) ms per token")


def build_command(threads):
    """Return the argv list for one run of the binary with *threads* threads.

    The template is split with shell-like quoting rules so the quoted prompt
    stays a single argument and no shell has to be spawned.
    """
    return shlex.split(cmd.format(threads=threads))


def find_token_time(output):
    """Return the first "<float> ms per token" figure in *output*.

    Returns the value as a float, or None when the text contains no such
    figure.
    """
    match = _TOKEN_TIME_RE.search(output)
    return float(match.group(1)) if match else None


def run_once(threads):
    """Run the command once and return the token time it reported (ms).

    The combined stdout/stderr of the process is echoed to stdout.

    Raises:
        RuntimeError: if the output contains no "ms per token" figure.
    """
    completed = subprocess.run(
        build_command(threads),
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
    )
    # The captured text is arbitrary program output and may contain a
    # truncated multi-byte sequence; never let decoding abort the benchmark.
    output = completed.stdout.decode(errors="replace")
    print(output)
    token_time = find_token_time(output)
    if token_time is None:
        raise RuntimeError(
            f"no 'ms per token' figure in the output of the run with "
            f"{threads} threads (exit code {completed.returncode})"
        )
    return token_time


def run_benchmark():
    """Measure the average token time for every configured thread count.

    Returns:
        A ``(threads_list, token_time_list)`` pair of equally long lists:
        the thread counts tried and the mean token time (ms) over
        ``n_runs`` runs for each of them.

    Raises:
        ValueError: if ``n_runs`` is smaller than 1 (nothing to average).
    """
    if n_runs < 1:
        raise ValueError("n_runs must be at least 1")

    threads_list = []
    token_time_list = []
    for threads in range(min_threads, max_threads + 1, step):
        print(f"Running with {threads} threads...")
        token_times = []
        for run in range(n_runs):
            token_time = run_once(threads)
            print(f"\t {threads} threads | run {run+1}/{n_runs} | current token time {round(token_time, 2)} ms")
            token_times.append(token_time)
        # Average token time for the current number of threads.
        token_time_list.append(sum(token_times) / len(token_times))
        threads_list.append(threads)
    return threads_list, token_time_list


def plot_results(threads_list, token_time_list):
    """Plot the average token time against the number of threads."""
    plt.plot(threads_list, token_time_list)
    plt.xlabel("Number of threads")
    plt.ylabel("Token time (ms)")
    plt.title("Token time vs Number of threads")
    plt.show()


if __name__ == "__main__":
    threads_list, token_time_list = run_benchmark()
    plot_results(threads_list, token_time_list)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment