-
-
Save maciejgryka/38af38d120d3129cad641ea1a05d07a8 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Mix.install([ | |
| :req, | |
| :finch, | |
| :explorer | |
| ]) | |
| alias Explorer.Series | |
| require Logger | |
| defmodule Dataset do | |
| def get_prompt(question, context) when is_binary(question) and is_binary(context) do | |
| [ | |
| %{ | |
| role: "system", | |
| content: "Your system prompt here" | |
| }, | |
| %{ | |
| role: "user", | |
| content: """ | |
| <context>#{context}</context> | |
| <question>#{question}</question> | |
| """ | |
| } | |
| ] | |
| end | |
| def load(path) do | |
| path | |
| |> File.stream!() | |
| |> Enum.map(&Jason.decode!/1) | |
| |> Enum.to_list() | |
| end | |
| end | |
| defmodule Stats do | |
| def compute_class_accuracy(results) do | |
| case Enum.filter(results, fn %{generated: g} -> is_binary(g) end) do | |
| [] -> | |
| %{label_accuracy: 0.0} | |
| valid_results -> | |
| {label_results} = | |
| Enum.reduce( | |
| valid_results, | |
| {[]}, | |
| fn %{generated: generated, expected: expected}, {label_results} -> | |
| label_correct = if generated == expected, do: 1, else: 0 | |
| {[label_correct | label_results]} | |
| end | |
| ) | |
| %{ | |
| label_accuracy: Enum.sum(label_results) / length(label_results) | |
| } | |
| end | |
| end | |
| def compute_timing(results) do | |
| durations = | |
| Enum.reduce(results, [], fn %{duration_ms: duration_ms}, duration_results -> | |
| [duration_ms | duration_results] | |
| end) | |
| durations_series = durations |> Series.from_list() | |
| p50 = Series.quantile(durations_series, 0.50) | |
| p95 = Series.quantile(durations_series, 0.95) | |
| p99 = Series.quantile(durations_series, 0.99) | |
| %{ | |
| p50_duration_s: p50 / 1000, | |
| p95_duration_s: p95 / 1000, | |
| p99_duration_s: p99 / 1000 | |
| } | |
| end | |
| def compute_tokens(results) do | |
| {input_tokens, output_tokens} = | |
| Enum.reduce(results, {0, 0}, fn result, {in_acc, out_acc} -> | |
| in_tok = Map.get(result, :input_tokens, 0) || 0 | |
| out_tok = Map.get(result, :output_tokens, 0) || 0 | |
| {in_acc + in_tok, out_acc + out_tok} | |
| end) | |
| %{ | |
| total_input_tokens: input_tokens, | |
| total_output_tokens: output_tokens, | |
| total_tokens: input_tokens + output_tokens | |
| } | |
| end | |
| end | |
| defmodule TrafficPattern do | |
| def build_phases( | |
| %{ | |
| baseline_rps: baseline, | |
| warmup_duration_s: warmup, | |
| cooldown_duration_s: cooldown, | |
| slope_rps_per_s: slope, | |
| hold_duration_s: hold, | |
| peaks: peaks | |
| } = _config | |
| ) do | |
| initial_phases = [{0, warmup, baseline, baseline}] | |
| initial_time = warmup | |
| {phases, _final_time} = | |
| peaks | |
| |> Enum.with_index() | |
| |> Enum.reduce({initial_phases, initial_time}, fn {peak, index}, {phases, t} -> | |
| ramp_duration = (peak - baseline) / slope | |
| is_last = index == length(peaks) - 1 | |
| spike_phases = [ | |
| {t, t + ramp_duration, baseline, peak}, | |
| {t + ramp_duration, t + ramp_duration + hold, peak, peak}, | |
| {t + ramp_duration + hold, t + 2 * ramp_duration + hold, peak, baseline} | |
| ] | |
| # Add cooldown after spike (except for the last one) | |
| spike_phases = | |
| if is_last do | |
| spike_phases | |
| else | |
| spike_phases ++ | |
| [ | |
| {t + 2 * ramp_duration + hold, t + 2 * ramp_duration + hold + cooldown, baseline, | |
| baseline} | |
| ] | |
| end | |
| next_time = | |
| if is_last, | |
| do: t + 2 * ramp_duration + hold, | |
| else: t + 2 * ramp_duration + hold + cooldown | |
| {phases ++ spike_phases, next_time} | |
| end) | |
| phases | |
| end | |
| def total_duration(phases) do | |
| phases |> List.last() |> elem(1) |> ceil() | |
| end | |
| def rps_at(phases, time_s) do | |
| phase = | |
| Enum.find(phases, fn phase -> | |
| start_s = elem(phase, 0) | |
| end_s = elem(phase, 1) | |
| time_s >= start_s and time_s < end_s | |
| end) | |
| case phase do | |
| {_start, _end, rps, rps} -> | |
| rps | |
| {start_s, end_s, from_rps, to_rps} when from_rps != to_rps -> | |
| progress = (time_s - start_s) / (end_s - start_s) | |
| from_rps + (to_rps - from_rps) * progress | |
| nil -> | |
| # Past the end, return baseline (shouldn't happen normally) | |
| 0 | |
| end | |
| end | |
| def generate_schedule(config) do | |
| phases = build_phases(config) | |
| duration = total_duration(phases) | |
| IO.puts( | |
| "Generated #{length(phases)} phases over #{duration} seconds. Peaks: #{inspect(config.peaks, charlists: :as_lists)} req/s" | |
| ) | |
| 0..(duration - 1) | |
| |> Enum.flat_map(fn second -> | |
| rps = rps_at(phases, second) | |
| generate_requests_in_second(second * 1000, round(rps)) | |
| end) | |
| |> Enum.sort() | |
| end | |
| defp generate_requests_in_second(_base_ms, 0), do: [] | |
| defp generate_requests_in_second(base_ms, count) when count > 0 do | |
| interval = 1000 / count | |
| Enum.map(0..(count - 1), fn i -> | |
| jitter = Enum.random(-50..50) | |
| max(0, round(base_ms + i * interval + jitter)) | |
| end) | |
| end | |
| end | |
| defmodule LoadTest do | |
| @provider :vertex_ai | |
| @providers %{ | |
| vertex_ai: %{ | |
| format: :openai, | |
| # Set GCP_PROJECT_ID and GCP_ACCESS_TOKEN env vars | |
| # export GCP_ACCESS_TOKEN=$(gcloud auth print-access-token) | |
| url: | |
| "https://aiplatform.googleapis.com/v1beta1/projects/#{System.get_env("GCP_PROJECT_ID")}/locations/global/endpoints/openapi/chat/completions", | |
| api_key_env: "GCP_ACCESS_TOKEN", | |
| model: "google/gemini-2.5-flash-lite" | |
| }, | |
| openai: %{ | |
| format: :openai, | |
| url: "https://api.openai.com/v1/chat/completions", | |
| api_key_env: "OPENAI_API_KEY", | |
| model: "gpt-5-nano" | |
| } | |
| } | |
| defp get_config do | |
| config = Map.fetch!(@providers, @provider) | |
| api_key = | |
| config[:api_key] || System.get_env(config[:api_key_env]) || | |
| raise "API key not found: set #{config[:api_key_env]} env var" | |
| {config.format, config.url, api_key, config.model} | |
| end | |
| def start_finch do | |
| {:ok, _} = | |
| Finch.start_link( | |
| name: LoadTestFinch, | |
| pools: %{ | |
| default: [ | |
| size: 1000, | |
| count: 100, | |
| conn_max_idle_time: :infinity, | |
| pool_max_idle_time: :infinity, | |
| conn_opts: [ | |
| transport_opts: [ | |
| timeout: 300_000 | |
| ] | |
| ] | |
| ] | |
| } | |
| ) | |
| end | |
| defp extract_answer(:openai, %Req.Response{body: body}) when is_map(body) do | |
| case body do | |
| %{"choices" => [%{"message" => %{"content" => content}} | _]} -> | |
| String.trim(content) | |
| _ -> | |
| # Error response or unexpected format | |
| nil | |
| end | |
| end | |
| # Vertex AI returns errors as a list | |
| defp extract_answer(:openai, %Req.Response{body: body}) do | |
| Logger.error("Unexpected response format: #{inspect(body)}") | |
| nil | |
| end | |
| defp extract_tokens(:openai, %Req.Response{body: body}) when is_map(body) do | |
| case body do | |
| %{"usage" => %{"prompt_tokens" => input, "completion_tokens" => output}} -> | |
| {input, output} | |
| _ -> | |
| {nil, nil} | |
| end | |
| end | |
| defp extract_tokens(:openai, _resp), do: {nil, nil} | |
| def run_rate_limited(requests_per_second, duration_seconds, report_interval_s \\ 5) do | |
| start_finch() | |
| {format, url, api_key, model} = get_config() | |
| dataset = Dataset.load("path/to/your/train.jsonl") | |
| total_requests = requests_per_second * duration_seconds | |
| delay_ms = (1000 / requests_per_second) |> round() | |
| IO.puts( | |
| "Making #{requests_per_second} req/s over #{duration_seconds}s (#{total_requests} requests, #{delay_ms}ms apart)" | |
| ) | |
| samples = dataset |> Stream.cycle() |> Enum.take(total_requests) | |
| {:ok, collector_pid} = Agent.start_link(fn -> {[], 0, 0} end) | |
| reporter = | |
| spawn_link(fn -> | |
| report_stats_periodically(collector_pid, report_interval_s * 1000, total_requests) | |
| end) | |
| tasks = | |
| samples | |
| |> Enum.with_index() | |
| |> Enum.map(fn {sample, index} -> | |
| jitter = 0 | |
| start_delay = max(index * delay_ms + jitter, 0) | |
| Task.async(fn -> | |
| Process.sleep(start_delay) | |
| messages = Dataset.get_prompt(sample["question"], sample["context"]) | |
| start_time = System.monotonic_time(:millisecond) | |
| with {:ok, resp} <- make_request(format, url, api_key, model, messages), | |
| generated <- extract_answer(format, resp) do | |
| end_time = System.monotonic_time(:millisecond) | |
| expected = sample |> Map.get("answer") | |
| result = | |
| %{ | |
| generated: generated, | |
| expected: expected, | |
| duration_ms: end_time - start_time | |
| } | |
| Agent.update(collector_pid, fn {current_batch_results, num_completed, num_failed} -> | |
| {[result | current_batch_results], num_completed + 1, num_failed} | |
| end) | |
| result | |
| else | |
| _reason -> | |
| Agent.update(collector_pid, fn {batch, num_completed, num_failed} -> | |
| {batch, num_completed, num_failed + 1} | |
| end) | |
| nil | |
| end | |
| end) | |
| end) | |
| results = tasks |> Enum.map(&Task.await(&1, 300_000)) |> Enum.filter(&(&1 != nil)) | |
| Process.exit(reporter, :normal) | |
| IO.puts("\n\n=== Final Results ===") | |
| IO.inspect(Stats.compute_class_accuracy(results), label: "Accuracy") | |
| IO.inspect(Stats.compute_timing(results), label: "Timing") | |
| results | |
| end | |
| def run_spike_test(traffic_config, report_interval_s \\ 5) do | |
| start_finch() | |
| {format, url, api_key, model} = get_config() | |
| # when measuring timing, use training data to have a larger dataset | |
| # in this case accuracy is not meaningful | |
| dataset = Dataset.load("path/to/your/train.jsonl") | |
| schedule = TrafficPattern.generate_schedule(traffic_config) | |
| total_requests = length(schedule) | |
| IO.puts("Will make #{total_requests} requests") | |
| samples = dataset |> Stream.cycle() |> Enum.take(total_requests) | |
| {:ok, collector_pid} = Agent.start_link(fn -> {[], 0, 0} end) | |
| reporter = | |
| spawn_link(fn -> | |
| report_stats_periodically(collector_pid, report_interval_s * 1000, total_requests) | |
| end) | |
| tasks = | |
| schedule | |
| |> Enum.with_index() | |
| |> Enum.map(fn {start_time_ms, index} -> | |
| sample = Enum.at(samples, index) | |
| Task.async(fn -> | |
| Process.sleep(start_time_ms) | |
| messages = Dataset.get_prompt(sample["question"], sample["context"]) | |
| start_time = System.monotonic_time(:millisecond) | |
| case make_request(format, url, api_key, model, messages) do | |
| {:ok, %Req.Response{status: status} = resp} when status >= 200 and status < 300 -> | |
| end_time = System.monotonic_time(:millisecond) | |
| generated = extract_answer(format, resp) | |
| {input_tokens, output_tokens} = extract_tokens(format, resp) | |
| expected = Map.get(sample, "answer") | |
| result = %{ | |
| generated: generated, | |
| expected: expected, | |
| duration_ms: end_time - start_time, | |
| input_tokens: input_tokens, | |
| output_tokens: output_tokens | |
| } | |
| Agent.update(collector_pid, fn {current_batch_results, num_completed, num_failed} -> | |
| {[result | current_batch_results], num_completed + 1, num_failed} | |
| end) | |
| result | |
| {:ok, %Req.Response{status: status, body: body}} -> | |
| Logger.error("HTTP #{status}: #{inspect(body)}") | |
| Agent.update(collector_pid, fn {batch, num_completed, num_failed} -> | |
| {batch, num_completed, num_failed + 1} | |
| end) | |
| nil | |
| {:error, %Req.TransportError{reason: :timeout}} -> | |
| Agent.update(collector_pid, fn {batch, num_completed, num_failed} -> | |
| {batch, num_completed, num_failed + 1} | |
| end) | |
| nil | |
| {:error, _reason} -> | |
| Agent.update(collector_pid, fn {batch, num_completed, num_failed} -> | |
| {batch, num_completed, num_failed + 1} | |
| end) | |
| nil | |
| end | |
| end) | |
| end) | |
| results = tasks |> Enum.map(&Task.await(&1, 600_000)) |> Enum.filter(&(&1 != nil)) | |
| Process.exit(reporter, :normal) | |
| IO.puts("\n\n=== Final Results ===") | |
| IO.inspect(Stats.compute_class_accuracy(results), label: "Accuracy") | |
| IO.inspect(Stats.compute_timing(results), label: "Timing") | |
| IO.inspect(Stats.compute_tokens(results), label: "Tokens") | |
| IO.puts("Completed: #{length(results)}/#{total_requests}") | |
| results | |
| end | |
| def run_accuracy_test(data_path, rps \\ 5, report_interval_s \\ 5) do | |
| start_finch() | |
| {format, url, api_key, model} = get_config() | |
| samples = Dataset.load(data_path) | |
| total_requests = length(samples) | |
| delay_ms = round(1000 / rps) | |
| IO.puts("Running accuracy test on #{total_requests} samples at #{rps} RPS") | |
| IO.puts("Estimated duration: #{round(total_requests / rps)} seconds") | |
| {:ok, collector_pid} = Agent.start_link(fn -> {[], 0, 0} end) | |
| reporter = | |
| spawn_link(fn -> | |
| report_stats_periodically(collector_pid, report_interval_s * 1000, total_requests) | |
| end) | |
| tasks = | |
| samples | |
| |> Enum.with_index() | |
| |> Enum.map(fn {sample, index} -> | |
| start_delay = index * delay_ms | |
| Task.async(fn -> | |
| Process.sleep(start_delay) | |
| messages = Dataset.get_prompt(sample["question"], sample["context"]) | |
| start_time = System.monotonic_time(:millisecond) | |
| case make_request(format, url, api_key, model, messages) do | |
| {:ok, %Req.Response{status: status} = resp} when status >= 200 and status < 300 -> | |
| end_time = System.monotonic_time(:millisecond) | |
| generated = extract_answer(format, resp) | |
| {input_tokens, output_tokens} = extract_tokens(format, resp) | |
| expected = Map.get(sample, "answer") | |
| result = %{ | |
| index: index, | |
| generated: generated, | |
| expected: expected, | |
| duration_ms: end_time - start_time, | |
| input_tokens: input_tokens, | |
| output_tokens: output_tokens | |
| } | |
| Agent.update(collector_pid, fn {current_batch_results, num_completed, num_failed} -> | |
| {[result | current_batch_results], num_completed + 1, num_failed} | |
| end) | |
| result | |
| {:ok, %Req.Response{status: status, body: body}} -> | |
| Logger.error("HTTP #{status} on sample #{index}: #{inspect(body)}") | |
| Agent.update(collector_pid, fn {batch, num_completed, num_failed} -> | |
| {batch, num_completed, num_failed + 1} | |
| end) | |
| nil | |
| {:error, %Req.TransportError{reason: :timeout}} -> | |
| Logger.warning("Timeout on sample #{index}") | |
| Agent.update(collector_pid, fn {batch, num_completed, num_failed} -> | |
| {batch, num_completed, num_failed + 1} | |
| end) | |
| nil | |
| {:error, reason} -> | |
| Logger.error("Error on sample #{index}: #{inspect(reason)}") | |
| Agent.update(collector_pid, fn {batch, num_completed, num_failed} -> | |
| {batch, num_completed, num_failed + 1} | |
| end) | |
| nil | |
| end | |
| end) | |
| end) | |
| results = tasks |> Enum.map(&Task.await(&1, 600_000)) |> Enum.filter(&(&1 != nil)) | |
| Process.exit(reporter, :normal) | |
| # Find misclassified samples | |
| misclassified = | |
| results | |
| |> Enum.filter(fn %{generated: g, expected: e} -> g != e end) | |
| |> Enum.sort_by(& &1.index) | |
| IO.puts("\n\n=== Final Results ===") | |
| IO.inspect(Stats.compute_class_accuracy(results), label: "Accuracy") | |
| IO.inspect(Stats.compute_timing(results), label: "Timing") | |
| IO.inspect(Stats.compute_tokens(results), label: "Tokens") | |
| IO.puts("Completed: #{length(results)}/#{total_requests}") | |
| IO.puts("Misclassified: #{length(misclassified)}") | |
| if length(misclassified) > 0 do | |
| IO.puts("\n=== Misclassified Samples (first 10) ===") | |
| misclassified | |
| |> Enum.take(10) | |
| |> Enum.each(fn %{index: i, generated: g, expected: e} -> | |
| IO.puts(" ##{i}: expected \"#{e}\", got \"#{g}\"") | |
| end) | |
| end | |
| {results, misclassified} | |
| end | |
| defp format_float(duration) do | |
| :erlang.float_to_binary(duration, decimals: 2) | |
| end | |
| defp report_stats_periodically(collector_pid, interval_ms, total_requests) do | |
| Process.sleep(interval_ms) | |
| {current_batch_results, num_completed, num_failed} = Agent.get(collector_pid, & &1) | |
| num_completed_batch = length(current_batch_results) | |
| batch_rps = 1000 * num_completed_batch / interval_ms | |
| if num_completed_batch > 0 do | |
| stats = Stats.compute_timing(current_batch_results) | |
| p50 = format_float(stats.p50_duration_s) | |
| p95 = format_float(stats.p95_duration_s) | |
| p99 = format_float(stats.p99_duration_s) | |
| timestamp = DateTime.utc_now() |> DateTime.to_iso8601() |> String.split(".") |> List.first() | |
| IO.puts( | |
| "#{timestamp} p50: #{p50}s p95: #{p95}s p99: #{p99}s (#{num_completed}/#{total_requests} completed at #{batch_rps} RPS, #{num_failed} failed)" | |
| ) | |
| Agent.update(collector_pid, fn {_, _, _} -> {[], num_completed, num_failed} end) | |
| end | |
| if num_completed < total_requests do | |
| report_stats_periodically(collector_pid, interval_ms, total_requests) | |
| end | |
| end | |
| def make_request(:openai, url, api_key, model, messages) do | |
| Req.post(url, | |
| auth: {:bearer, api_key}, | |
| finch: LoadTestFinch, | |
| receive_timeout: 8_000, | |
| json: %{ | |
| model: model, | |
| messages: messages, | |
| temperature: 0.0, | |
| max_tokens: 50 | |
| } | |
| ) | |
| end | |
| end | |
| # ============================================================================= | |
| # Example usage - uncomment the test you want to run | |
| # ============================================================================= | |
| # Constant rate test: | |
| # LoadTest.run_rate_limited(1, 5, 1) | |
| # Spike test with traffic ramping: | |
| # LoadTest.run_spike_test( | |
| # %{ | |
| # baseline_rps: 1, | |
| # warmup_duration_s: 30, | |
| # cooldown_duration_s: 30, | |
| # slope_rps_per_s: 0.1, | |
| # hold_duration_s: 60, | |
| # peaks: [6] | |
| # }, | |
| # 5 | |
| # ) | |
| # Accuracy test on test data: | |
| # LoadTest.run_accuracy_test("path/to/your/test.jsonl", 5) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment