Skip to content

Instantly share code, notes, and snippets.

@maciejgryka
Created January 28, 2026 15:54
Show Gist options
  • Select an option

  • Save maciejgryka/38af38d120d3129cad641ea1a05d07a8 to your computer and use it in GitHub Desktop.

Select an option

Save maciejgryka/38af38d120d3129cad641ea1a05d07a8 to your computer and use it in GitHub Desktop.
Mix.install([
:req,
:finch,
:explorer
])
alias Explorer.Series
require Logger
defmodule Dataset do
  @moduledoc """
  JSONL dataset loading and chat-prompt construction for the load tests.
  """

  @doc """
  Builds the chat message list for one sample: a fixed system prompt
  followed by a user message wrapping `context` and `question` in
  XML-style tags.

  Both arguments must be binaries (enforced by the guard).
  """
  def get_prompt(question, context) when is_binary(question) and is_binary(context) do
    [
      %{
        role: "system",
        content: "Your system prompt here"
      },
      %{
        role: "user",
        content: """
        <context>#{context}</context>
        <question>#{question}</question>
        """
      }
    ]
  end

  @doc """
  Reads a JSONL file at `path` into a list of decoded maps, one per line.

  Raises if the file cannot be read or a line is not valid JSON.
  """
  def load(path) do
    path
    |> File.stream!()
    # Enum.map/2 already materializes a list; the previous trailing
    # Enum.to_list/1 was a redundant O(n) pass.
    |> Enum.map(&Jason.decode!/1)
  end
end
defmodule Stats do
  @moduledoc """
  Aggregate statistics over load-test result maps.

  Each result map is expected to carry `:generated`/`:expected` labels
  and `:duration_ms`; token counts are optional.
  """

  @doc """
  Fraction of valid results whose generated label equals the expected one.

  Results whose `:generated` value is not a binary (e.g. `nil` after a
  failed extraction) are excluded from the denominator. Returns
  `%{label_accuracy: 0.0}` when there are no valid results.
  """
  def compute_class_accuracy(results) do
    valid_results = Enum.filter(results, fn %{generated: g} -> is_binary(g) end)

    case valid_results do
      [] ->
        %{label_accuracy: 0.0}

      _ ->
        # Enum.count/2 replaces the previous reduce over a 1-element
        # tuple accumulator; `/` always yields a float.
        correct = Enum.count(valid_results, fn %{generated: g, expected: e} -> g == e end)
        %{label_accuracy: correct / length(valid_results)}
    end
  end

  @doc """
  p50/p95/p99 request latency in seconds, computed with Explorer.

  NOTE(review): `Series.from_list([])` / `quantile` will not produce
  numbers for an empty result list, so this crashes when every request
  failed — callers should guard.
  """
  def compute_timing(results) do
    durations =
      Enum.reduce(results, [], fn %{duration_ms: duration_ms}, duration_results ->
        [duration_ms | duration_results]
      end)

    durations_series = durations |> Series.from_list()
    p50 = Series.quantile(durations_series, 0.50)
    p95 = Series.quantile(durations_series, 0.95)
    p99 = Series.quantile(durations_series, 0.99)

    %{
      p50_duration_s: p50 / 1000,
      p95_duration_s: p95 / 1000,
      p99_duration_s: p99 / 1000
    }
  end

  @doc """
  Sums input/output token counts across results.

  Missing or `nil` token fields count as zero, so partially populated
  results are tolerated.
  """
  def compute_tokens(results) do
    {input_tokens, output_tokens} =
      Enum.reduce(results, {0, 0}, fn result, {in_acc, out_acc} ->
        in_tok = Map.get(result, :input_tokens, 0) || 0
        out_tok = Map.get(result, :output_tokens, 0) || 0
        {in_acc + in_tok, out_acc + out_tok}
      end)

    %{
      total_input_tokens: input_tokens,
      total_output_tokens: output_tokens,
      total_tokens: input_tokens + output_tokens
    }
  end
end
defmodule TrafficPattern do
  @moduledoc """
  Builds spike-shaped traffic schedules.

  A schedule is derived from phases, each a tuple
  `{start_s, end_s, from_rps, to_rps}`: constant rate when the two rates
  are equal, linear ramp otherwise.
  """

  @doc """
  Expands a traffic config into an ordered list of phase tuples.

  The pattern is: warmup at baseline, then for each peak a ramp up, a
  hold, and a ramp down, with a cooldown at baseline between peaks (but
  not after the final one).
  """
  def build_phases(%{
        baseline_rps: baseline,
        warmup_duration_s: warmup,
        cooldown_duration_s: cooldown,
        slope_rps_per_s: slope,
        hold_duration_s: hold,
        peaks: peaks
      }) do
    last_index = length(peaks) - 1

    {phases, _end_time} =
      peaks
      |> Enum.with_index()
      |> Enum.reduce({[{0, warmup, baseline, baseline}], warmup}, fn {peak_rps, i}, {acc, start} ->
        # Time to ramp between baseline and this peak at the given slope.
        ramp = (peak_rps - baseline) / slope
        spike_end = start + 2 * ramp + hold

        up_hold_down = [
          {start, start + ramp, baseline, peak_rps},
          {start + ramp, start + ramp + hold, peak_rps, peak_rps},
          {start + ramp + hold, spike_end, peak_rps, baseline}
        ]

        # The final peak gets no trailing cooldown phase.
        if i == last_index do
          {acc ++ up_hold_down, spike_end}
        else
          with_cooldown = up_hold_down ++ [{spike_end, spike_end + cooldown, baseline, baseline}]
          {acc ++ with_cooldown, spike_end + cooldown}
        end
      end)

    phases
  end

  @doc """
  Total schedule length in whole seconds (end of the last phase, rounded up).
  """
  def total_duration(phases) do
    {_start, end_s, _from, _to} = List.last(phases)
    ceil(end_s)
  end

  @doc """
  Instantaneous request rate at `time_s`.

  Constant phases return their rate directly; ramps interpolate
  linearly. Times past the final phase yield 0.
  """
  def rps_at(phases, time_s) do
    current =
      Enum.find(phases, fn {start_s, end_s, _from, _to} ->
        time_s >= start_s and time_s < end_s
      end)

    case current do
      {_start, _end, rate, rate} ->
        rate

      {start_s, end_s, from_rps, to_rps} when from_rps != to_rps ->
        fraction = (time_s - start_s) / (end_s - start_s)
        from_rps + (to_rps - from_rps) * fraction

      nil ->
        # Past the end, return baseline (shouldn't happen normally)
        0
    end
  end

  @doc """
  Produces a sorted list of request start offsets (milliseconds) for the
  whole schedule, one batch per second at that second's rate.
  """
  def generate_schedule(config) do
    phases = build_phases(config)
    duration = total_duration(phases)

    IO.puts(
      "Generated #{length(phases)} phases over #{duration} seconds. Peaks: #{inspect(config.peaks, charlists: :as_lists)} req/s"
    )

    0..(duration - 1)
    |> Enum.flat_map(fn second ->
      rate = rps_at(phases, second)
      generate_requests_in_second(second * 1000, round(rate))
    end)
    |> Enum.sort()
  end

  # Spreads `count` request offsets evenly across one second starting at
  # `base_ms`, with +/-50ms of random jitter, clamped to non-negative.
  defp generate_requests_in_second(_base_ms, 0), do: []

  defp generate_requests_in_second(base_ms, count) when count > 0 do
    step = 1000 / count

    for i <- 0..(count - 1) do
      noise = Enum.random(-50..50)
      max(0, round(base_ms + i * step + noise))
    end
  end
end
defmodule LoadTest do
  @moduledoc """
  Load-testing harness for OpenAI-compatible chat-completion endpoints.

  Entry points:

    * `run_rate_limited/3` — constant request rate for a fixed duration
    * `run_spike_test/2` — traffic shaped by `TrafficPattern` phases
    * `run_accuracy_test/3` — one request per dataset sample, with
      misclassification reporting

  Requests go through a dedicated Finch pool (`LoadTestFinch`); progress
  is aggregated in an Agent whose state is
  `{current_batch_results, num_completed, num_failed}` and printed
  periodically by a reporter process.
  """

  # Switch providers by editing this attribute.
  @provider :vertex_ai

  # NOTE(review): the Vertex AI URL interpolates GCP_PROJECT_ID when the
  # module is compiled. In a Mix.install script that happens at startup,
  # but if this module were compiled ahead of time the env var would be
  # frozen at compile time — confirm that is intended.
  @providers %{
    vertex_ai: %{
      format: :openai,
      # Set GCP_PROJECT_ID and GCP_ACCESS_TOKEN env vars
      # export GCP_ACCESS_TOKEN=$(gcloud auth print-access-token)
      url:
        "https://aiplatform.googleapis.com/v1beta1/projects/#{System.get_env("GCP_PROJECT_ID")}/locations/global/endpoints/openapi/chat/completions",
      api_key_env: "GCP_ACCESS_TOKEN",
      model: "google/gemini-2.5-flash-lite"
    },
    openai: %{
      format: :openai,
      url: "https://api.openai.com/v1/chat/completions",
      api_key_env: "OPENAI_API_KEY",
      model: "gpt-5-nano"
    }
  }

  # Resolves the active provider to {format, url, api_key, model},
  # raising if no API key can be found in config or the environment.
  defp get_config do
    config = Map.fetch!(@providers, @provider)

    api_key =
      config[:api_key] || System.get_env(config[:api_key_env]) ||
        raise "API key not found: set #{config[:api_key_env]} env var"

    {config.format, config.url, api_key, config.model}
  end

  @doc """
  Starts the Finch pool used by all requests.

  Sized generously (100 pools x 1000 connections, long transport
  timeout) so connection checkout is not the bottleneck during spikes.
  """
  def start_finch do
    {:ok, _} =
      Finch.start_link(
        name: LoadTestFinch,
        pools: %{
          default: [
            size: 1000,
            count: 100,
            conn_max_idle_time: :infinity,
            pool_max_idle_time: :infinity,
            conn_opts: [
              transport_opts: [
                timeout: 300_000
              ]
            ]
          ]
        }
      )
  end

  # Pulls the assistant message content out of an OpenAI-format
  # response body; nil on error bodies or unexpected shapes.
  defp extract_answer(:openai, %Req.Response{body: body}) when is_map(body) do
    case body do
      %{"choices" => [%{"message" => %{"content" => content}} | _]} ->
        String.trim(content)

      _ ->
        # Error response or unexpected format
        nil
    end
  end

  # Vertex AI returns errors as a list
  defp extract_answer(:openai, %Req.Response{body: body}) do
    Logger.error("Unexpected response format: #{inspect(body)}")
    nil
  end

  # Returns {prompt_tokens, completion_tokens} from the usage block,
  # or {nil, nil} when absent.
  defp extract_tokens(:openai, %Req.Response{body: body}) when is_map(body) do
    case body do
      %{"usage" => %{"prompt_tokens" => input, "completion_tokens" => output}} ->
        {input, output}

      _ ->
        {nil, nil}
    end
  end

  defp extract_tokens(:openai, _resp), do: {nil, nil}

  @doc """
  Fires `requests_per_second * duration_seconds` requests at a constant
  rate, cycling through the training dataset.

  Prints accuracy and timing stats when done and returns the list of
  per-request result maps (failed requests are dropped).
  """
  def run_rate_limited(requests_per_second, duration_seconds, report_interval_s \\ 5) do
    start_finch()
    {format, url, api_key, model} = get_config()
    dataset = Dataset.load("path/to/your/train.jsonl")

    total_requests = requests_per_second * duration_seconds
    delay_ms = (1000 / requests_per_second) |> round()

    IO.puts(
      "Making #{requests_per_second} req/s over #{duration_seconds}s (#{total_requests} requests, #{delay_ms}ms apart)"
    )

    samples = dataset |> Stream.cycle() |> Enum.take(total_requests)

    {:ok, collector_pid} = Agent.start_link(fn -> {[], 0, 0} end)

    reporter =
      spawn_link(fn ->
        report_stats_periodically(collector_pid, report_interval_s * 1000, total_requests)
      end)

    tasks =
      samples
      |> Enum.with_index()
      |> Enum.map(fn {sample, index} ->
        jitter = 0
        start_delay = max(index * delay_ms + jitter, 0)

        Task.async(fn ->
          Process.sleep(start_delay)
          messages = Dataset.get_prompt(sample["question"], sample["context"])
          start_time = System.monotonic_time(:millisecond)

          # Both `<-` clauses always match on {:ok, _}; the else branch
          # only triggers for a transport-level {:error, _}.
          # NOTE(review): unlike run_spike_test/2, non-2xx HTTP responses
          # are still counted as successes here — confirm if intended.
          with {:ok, resp} <- make_request(format, url, api_key, model, messages),
               generated <- extract_answer(format, resp) do
            end_time = System.monotonic_time(:millisecond)
            expected = sample |> Map.get("answer")

            result =
              %{
                generated: generated,
                expected: expected,
                duration_ms: end_time - start_time
              }

            Agent.update(collector_pid, fn {current_batch_results, num_completed, num_failed} ->
              {[result | current_batch_results], num_completed + 1, num_failed}
            end)

            result
          else
            _reason ->
              record_failure(collector_pid)
              nil
          end
        end)
      end)

    results = tasks |> Enum.map(&Task.await(&1, 300_000)) |> Enum.filter(&(&1 != nil))
    stop_reporter(reporter)

    IO.puts("\n\n=== Final Results ===")
    IO.inspect(Stats.compute_class_accuracy(results), label: "Accuracy")
    IO.inspect(Stats.compute_timing(results), label: "Timing")
    results
  end

  @doc """
  Runs a spike-shaped load test driven by `TrafficPattern`.

  `traffic_config` is the map accepted by `TrafficPattern.build_phases/1`.
  Prints accuracy, timing and token stats when done and returns the list
  of per-request result maps (failed requests are dropped).
  """
  def run_spike_test(traffic_config, report_interval_s \\ 5) do
    start_finch()
    {format, url, api_key, model} = get_config()

    # when measuring timing, use training data to have a larger dataset
    # in this case accuracy is not meaningful
    dataset = Dataset.load("path/to/your/train.jsonl")

    schedule = TrafficPattern.generate_schedule(traffic_config)
    total_requests = length(schedule)
    IO.puts("Will make #{total_requests} requests")

    samples = dataset |> Stream.cycle() |> Enum.take(total_requests)

    {:ok, collector_pid} = Agent.start_link(fn -> {[], 0, 0} end)

    reporter =
      spawn_link(fn ->
        report_stats_periodically(collector_pid, report_interval_s * 1000, total_requests)
      end)

    # Zip the schedule with its samples up front: the previous
    # Enum.at(samples, index) lookup was O(n) per request, making the
    # whole scheduling pass O(n^2) for large runs.
    tasks =
      schedule
      |> Enum.zip(samples)
      |> Enum.map(fn {start_time_ms, sample} ->
        Task.async(fn ->
          Process.sleep(start_time_ms)
          messages = Dataset.get_prompt(sample["question"], sample["context"])
          start_time = System.monotonic_time(:millisecond)

          case make_request(format, url, api_key, model, messages) do
            {:ok, %Req.Response{status: status} = resp} when status >= 200 and status < 300 ->
              end_time = System.monotonic_time(:millisecond)
              generated = extract_answer(format, resp)
              {input_tokens, output_tokens} = extract_tokens(format, resp)
              expected = Map.get(sample, "answer")

              result = %{
                generated: generated,
                expected: expected,
                duration_ms: end_time - start_time,
                input_tokens: input_tokens,
                output_tokens: output_tokens
              }

              Agent.update(collector_pid, fn {current_batch_results, num_completed, num_failed} ->
                {[result | current_batch_results], num_completed + 1, num_failed}
              end)

              result

            {:ok, %Req.Response{status: status, body: body}} ->
              Logger.error("HTTP #{status}: #{inspect(body)}")
              record_failure(collector_pid)
              nil

            {:error, %Req.TransportError{reason: :timeout}} ->
              record_failure(collector_pid)
              nil

            {:error, _reason} ->
              record_failure(collector_pid)
              nil
          end
        end)
      end)

    results = tasks |> Enum.map(&Task.await(&1, 600_000)) |> Enum.filter(&(&1 != nil))
    stop_reporter(reporter)

    IO.puts("\n\n=== Final Results ===")
    IO.inspect(Stats.compute_class_accuracy(results), label: "Accuracy")
    IO.inspect(Stats.compute_timing(results), label: "Timing")
    IO.inspect(Stats.compute_tokens(results), label: "Tokens")
    IO.puts("Completed: #{length(results)}/#{total_requests}")
    results
  end

  @doc """
  Sends every sample in `data_path` exactly once at `rps` requests/sec
  and reports accuracy, timing, token usage, and the misclassified
  samples.

  Returns `{results, misclassified}`.
  """
  def run_accuracy_test(data_path, rps \\ 5, report_interval_s \\ 5) do
    start_finch()
    {format, url, api_key, model} = get_config()
    samples = Dataset.load(data_path)
    total_requests = length(samples)
    delay_ms = round(1000 / rps)

    IO.puts("Running accuracy test on #{total_requests} samples at #{rps} RPS")
    IO.puts("Estimated duration: #{round(total_requests / rps)} seconds")

    {:ok, collector_pid} = Agent.start_link(fn -> {[], 0, 0} end)

    reporter =
      spawn_link(fn ->
        report_stats_periodically(collector_pid, report_interval_s * 1000, total_requests)
      end)

    tasks =
      samples
      |> Enum.with_index()
      |> Enum.map(fn {sample, index} ->
        start_delay = index * delay_ms

        Task.async(fn ->
          Process.sleep(start_delay)
          messages = Dataset.get_prompt(sample["question"], sample["context"])
          start_time = System.monotonic_time(:millisecond)

          case make_request(format, url, api_key, model, messages) do
            {:ok, %Req.Response{status: status} = resp} when status >= 200 and status < 300 ->
              end_time = System.monotonic_time(:millisecond)
              generated = extract_answer(format, resp)
              {input_tokens, output_tokens} = extract_tokens(format, resp)
              expected = Map.get(sample, "answer")

              result = %{
                index: index,
                generated: generated,
                expected: expected,
                duration_ms: end_time - start_time,
                input_tokens: input_tokens,
                output_tokens: output_tokens
              }

              Agent.update(collector_pid, fn {current_batch_results, num_completed, num_failed} ->
                {[result | current_batch_results], num_completed + 1, num_failed}
              end)

              result

            {:ok, %Req.Response{status: status, body: body}} ->
              Logger.error("HTTP #{status} on sample #{index}: #{inspect(body)}")
              record_failure(collector_pid)
              nil

            {:error, %Req.TransportError{reason: :timeout}} ->
              Logger.warning("Timeout on sample #{index}")
              record_failure(collector_pid)
              nil

            {:error, reason} ->
              Logger.error("Error on sample #{index}: #{inspect(reason)}")
              record_failure(collector_pid)
              nil
          end
        end)
      end)

    results = tasks |> Enum.map(&Task.await(&1, 600_000)) |> Enum.filter(&(&1 != nil))
    stop_reporter(reporter)

    # Find misclassified samples
    misclassified =
      results
      |> Enum.filter(fn %{generated: g, expected: e} -> g != e end)
      |> Enum.sort_by(& &1.index)

    IO.puts("\n\n=== Final Results ===")
    IO.inspect(Stats.compute_class_accuracy(results), label: "Accuracy")
    IO.inspect(Stats.compute_timing(results), label: "Timing")
    IO.inspect(Stats.compute_tokens(results), label: "Tokens")
    IO.puts("Completed: #{length(results)}/#{total_requests}")
    IO.puts("Misclassified: #{length(misclassified)}")

    if misclassified != [] do
      IO.puts("\n=== Misclassified Samples (first 10) ===")

      misclassified
      |> Enum.take(10)
      |> Enum.each(fn %{index: i, generated: g, expected: e} ->
        IO.puts("  ##{i}: expected \"#{e}\", got \"#{g}\"")
      end)
    end

    {results, misclassified}
  end

  # Counts one failed request in the collector without touching the
  # current batch or the completed count.
  defp record_failure(collector_pid) do
    Agent.update(collector_pid, fn {batch, num_completed, num_failed} ->
      {batch, num_completed, num_failed + 1}
    end)
  end

  # Stops the periodic reporter. Process.exit(pid, :normal) is ignored
  # when pid is another (non-trapping) process, so the previous
  # `Process.exit(reporter, :normal)` never stopped it — and since
  # failed requests never increment num_completed, the reporter could
  # poll forever. Unlink first so the :kill signal does not propagate
  # back to this process.
  defp stop_reporter(reporter) do
    Process.unlink(reporter)
    Process.exit(reporter, :kill)
  end

  # Formats a float with two decimals for the periodic report.
  defp format_float(duration) do
    :erlang.float_to_binary(duration, decimals: 2)
  end

  # Sleeps `interval_ms`, prints latency percentiles for the batch of
  # results collected since the last report, drains that batch, and
  # recurses until `total_requests` requests have completed.
  defp report_stats_periodically(collector_pid, interval_ms, total_requests) do
    Process.sleep(interval_ms)
    {current_batch_results, num_completed, num_failed} = Agent.get(collector_pid, & &1)
    num_completed_batch = length(current_batch_results)
    batch_rps = 1000 * num_completed_batch / interval_ms

    if num_completed_batch > 0 do
      stats = Stats.compute_timing(current_batch_results)
      p50 = format_float(stats.p50_duration_s)
      p95 = format_float(stats.p95_duration_s)
      p99 = format_float(stats.p99_duration_s)
      timestamp = DateTime.utc_now() |> DateTime.to_iso8601() |> String.split(".") |> List.first()

      IO.puts(
        "#{timestamp} p50: #{p50}s p95: #{p95}s p99: #{p99}s (#{num_completed}/#{total_requests} completed at #{batch_rps} RPS, #{num_failed} failed)"
      )

      # Reset the batch; completed/failed totals are cumulative.
      Agent.update(collector_pid, fn {_, _, _} -> {[], num_completed, num_failed} end)
    end

    if num_completed < total_requests do
      report_stats_periodically(collector_pid, interval_ms, total_requests)
    end
  end

  @doc """
  POSTs one chat-completion request in OpenAI format.

  Returns `{:ok, %Req.Response{}}` or `{:error, reason}`; the 8s
  `receive_timeout` bounds each request independently of the pool's
  transport timeout.
  """
  def make_request(:openai, url, api_key, model, messages) do
    Req.post(url,
      auth: {:bearer, api_key},
      finch: LoadTestFinch,
      receive_timeout: 8_000,
      json: %{
        model: model,
        messages: messages,
        temperature: 0.0,
        max_tokens: 50
      }
    )
  end
end
# =============================================================================
# Example usage - uncomment the test you want to run
# =============================================================================
# Constant rate test:
# LoadTest.run_rate_limited(1, 5, 1)
# Spike test with traffic ramping:
# LoadTest.run_spike_test(
# %{
# baseline_rps: 1,
# warmup_duration_s: 30,
# cooldown_duration_s: 30,
# slope_rps_per_s: 0.1,
# hold_duration_s: 60,
# peaks: [6]
# },
# 5
# )
# Accuracy test on test data:
# LoadTest.run_accuracy_test("path/to/your/test.jsonl", 5)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment