Skip to content

Instantly share code, notes, and snippets.

@justanotheratom
Last active December 10, 2025 04:10
Show Gist options
  • Select an option

  • Save justanotheratom/20893a4c04eddb47612247ea4e343c25 to your computer and use it in GitHub Desktop.

Select an option

Save justanotheratom/20893a4c04eddb47612247ea4e343c25 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python3
"""
Per-Module Usage Tracker for DSPy
Tracks token usage for individual sub-modules in multi-module DSPy programs.
Quick Start:
1. pip install dspy
2. export OPENAI_API_KEY="your-key"
3. Run: python per_module_usage_tracker.py
"""
import warnings
import logging
# Suppress warnings before importing modules that may emit them
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", message=".*Pydantic.*")
import dspy
from dspy.utils.callback import BaseCallback
from collections import defaultdict
class PerModuleUsageTracker(BaseCallback):
    """Tracks token usage for each sub-module separately with descriptive names.

    On every module start the tracker snapshots the running totals reported by
    ``dspy.settings.usage_tracker``; on the matching end it records the
    difference under a dotted path such as ``"TopLevelModule.nested.inner"``.
    Repeated calls that resolve to the same path are summed.

    Only modules that directly own a ``Predict`` attribute are tracked. Bare
    ``Predict`` instances, pure container modules and the configured top-level
    module are skipped. Because usage is measured between start and end, a
    tracked module's figures include any tracked modules it calls.

    NOTE(review): snapshots come from one shared running total, so modules that
    execute concurrently would likely be charged for each other's tokens.
    """

    def __init__(self, top_level_module=None):
        """
        Args:
            top_level_module: Optional root program. When given it provides the
                first path segment and is itself never tracked. When omitted,
                the outermost calling module is used as the root.
        """
        super().__init__()
        self.module_usage = defaultdict(dict)  # module_path -> {lm_name: usage dict}
        self.module_trackers = {}  # call_id -> (module_path, usage snapshot at start)
        self.top_level_module = top_level_module  # optional root used for naming

    def _find_module_attribute_name(self, module_instance, parent_module, _visited=None):
        """Return the attribute path of ``module_instance`` inside ``parent_module``.

        Direct attributes win. Otherwise nested ``dspy.Module`` attributes and
        modules stored in lists, tuples and dicts are searched recursively,
        yielding names such as ``"nested.inner"``, ``"layers[0]"`` or
        ``"heads['a'].proj"``.

        Args:
            module_instance: The module to locate (compared by identity).
            parent_module: The module whose attributes are searched; may be None.
            _visited: Internal set of already-searched module ids, which guards
                against reference cycles between modules.

        Returns:
            The attribute path, or ``None`` if ``parent_module`` is None or the
            instance is not reachable from it.
        """
        if parent_module is None:
            return None
        if _visited is None:
            _visited = set()
        if id(parent_module) in _visited:
            return None  # cycle (e.g. a child holding a reference to its parent)
        _visited.add(id(parent_module))

        attrs = getattr(parent_module, "__dict__", {})

        # Prefer a direct attribute over any nested match.
        for attr_name, attr_value in attrs.items():
            if attr_value is module_instance:
                return attr_name

        for attr_name, attr_value in attrs.items():
            # Normalise every attribute into (label, child) pairs so plain
            # attributes and container elements are searched the same way.
            if isinstance(attr_value, (list, tuple)):
                children = [(f"{attr_name}[{idx}]", item) for idx, item in enumerate(attr_value)]
            elif isinstance(attr_value, dict):
                children = [(f"{attr_name}['{key}']", item) for key, item in attr_value.items()]
            else:
                children = [(attr_name, attr_value)]

            for label, child in children:
                if child is module_instance:
                    return label
                if isinstance(child, dspy.Module):
                    nested_name = self._find_module_attribute_name(module_instance, child, _visited)
                    if nested_name:
                        return f"{label}.{nested_name}"
        return None

    def _has_direct_predictors(self, instance):
        """Return True if ``instance`` has a ``Predict`` as a direct attribute (non-recursive)."""
        from dspy.predict.predict import Predict

        attrs = getattr(instance, "__dict__", None)
        if not isinstance(attrs, dict):
            return False
        return any(isinstance(value, Predict) for value in attrs.values())

    def _build_module_path(self, instance):
        """Build a descriptive dotted path for ``instance``, or ``None`` to skip it.

        The path starts with the root's class name (``top_level_module`` when
        given, otherwise the outermost calling module), followed by the
        attribute names of the tracked ancestors and of ``instance`` itself,
        e.g. ``"TopLevelModule.nested.inner"``. An instance with no calling
        modules is named by its class name alone.
        """
        from dspy.predict.predict import Predict

        # Predict instances are accounted for through the module that owns them.
        if isinstance(instance, Predict):
            return None
        # The top-level module itself is not tracked, only its sub-modules.
        if self.top_level_module is not None and instance is self.top_level_module:
            return None
        # Pure containers only wrap other modules; their children are tracked instead.
        if not self._has_direct_predictors(instance):
            return None

        # Calling modules, outermost first. NOTE(review): this assumes DSPy fires
        # on_module_start before the module pushes itself onto caller_modules;
        # `instance` is filtered out below in case it does not.
        ancestors = list(dspy.settings.caller_modules or [])
        if not ancestors:
            return instance.__class__.__name__

        root = self.top_level_module if self.top_level_module is not None else ancestors[0]

        # Intermediate ancestors that contribute their own path segment.
        # Containers without predictors are dropped here; the recursive
        # attribute search still recovers their names (e.g. "nested.inner").
        chain = [
            module
            for module in ancestors
            if module is not root
            and module is not instance
            and not isinstance(module, Predict)
            and self._has_direct_predictors(module)
        ]

        parts = [root.__class__.__name__]
        parent = root
        for module in chain:
            attr_name = self._find_module_attribute_name(module, parent)
            if attr_name:
                parts.append(attr_name)
            parent = module

        # Fall back to the class name so an unresolvable instance never
        # collapses onto (and gets merged with) its parent's path.
        parts.append(
            self._find_module_attribute_name(instance, parent) or instance.__class__.__name__
        )
        return ".".join(parts)

    def _get_usage_snapshot(self):
        """Return the active usage tracker's totals per LM, or ``{}`` if tracking is off."""
        usage_tracker = dspy.settings.usage_tracker
        if usage_tracker:
            return usage_tracker.get_total_tokens()
        return {}

    def on_module_start(self, call_id, instance, inputs):
        """Callback hook: remember the usage totals at the start of a tracked module call."""
        module_path = self._build_module_path(instance)
        if module_path is None:
            return  # untracked (Predict, container, or the top-level module)
        self.module_trackers[call_id] = (module_path, self._get_usage_snapshot())

    def on_module_end(self, call_id, outputs, exception=None):
        """Callback hook: record the usage consumed since the matching start.

        Usage is recorded even when the module raised, since the tokens were
        still spent. Accounting failures never propagate to the program; they
        are logged at DEBUG level instead.
        """
        entry = self.module_trackers.pop(call_id, None)
        if entry is None:
            return
        module_path, initial_usage = entry

        try:
            final_usage = self._get_usage_snapshot()
            module_usage = self._calculate_usage_diff(
                initial_usage if isinstance(initial_usage, dict) else {},
                final_usage if isinstance(final_usage, dict) else {},
            )
        except Exception:
            # Best effort: keep the user's program running, but leave a trace.
            logging.getLogger(__name__).debug(
                "Failed to compute usage for module %s", module_path, exc_info=True
            )
            return

        if not module_usage:
            return
        if module_path in self.module_usage:
            # The same path was called before: accumulate.
            self.module_usage[module_path] = self._merge_usage(
                self.module_usage[module_path], module_usage
            )
        else:
            self.module_usage[module_path] = module_usage

    def _calculate_usage_diff(self, initial, final):
        """Return ``final - initial`` for every numeric leaf of ``final``.

        Works at any nesting depth, so both the ``{lm_name: usage}`` level and
        nested blocks such as ``prompt_tokens_details`` are diffed. Missing or
        ``None`` values in ``initial`` count as 0. Non-numeric leaves of
        ``final`` are omitted.
        """
        if not isinstance(final, dict):
            return {}
        if not isinstance(initial, dict):
            initial = {}

        diff = {}
        for key, final_value in final.items():
            initial_value = initial.get(key)
            if isinstance(final_value, dict):
                diff[key] = self._calculate_usage_diff(initial_value, final_value)
            elif isinstance(final_value, (int, float)):
                if initial_value is None:
                    initial_value = 0
                if isinstance(initial_value, (int, float)):
                    diff[key] = final_value - initial_value
        return diff

    def _merge_usage(self, usage1, usage2):
        """Sum two usage trees leaf by leaf, at any nesting depth.

        Missing or ``None`` leaves count as 0. Non-numeric, non-dict leaves are
        dropped. Key order follows ``usage1`` and then new keys of ``usage2``.
        """
        usage1 = usage1 if isinstance(usage1, dict) else {}
        usage2 = usage2 if isinstance(usage2, dict) else {}

        merged = {}
        for key in dict.fromkeys([*usage1, *usage2]):  # ordered union of keys
            v1 = usage1.get(key)
            v2 = usage2.get(key)
            if isinstance(v1, dict) or isinstance(v2, dict):
                merged[key] = self._merge_usage(v1, v2)
                continue
            v1 = 0 if v1 is None else v1
            v2 = 0 if v2 is None else v2
            if isinstance(v1, (int, float)) and isinstance(v2, (int, float)):
                merged[key] = v1 + v2
        return merged

    def get_module_usage(self):
        """Return ``{module_path: {lm_name: usage}}`` for all tracked modules (shallow copy)."""
        return dict(self.module_usage)

    def get_module_usage_by_path(self, path_prefix):
        """Return usage for modules whose path starts with ``path_prefix`` (plain string prefix)."""
        return {
            module_path: usage
            for module_path, usage in self.module_usage.items()
            if module_path.startswith(path_prefix)
        }
# Example usage with nested modules
if __name__ == "__main__":
import os
# Suppress callback errors for cleaner output
logging.getLogger("dspy.utils.callback").setLevel(logging.ERROR)
# Check for API key
if not os.getenv("OPENAI_API_KEY"):
print("Error: OPENAI_API_KEY environment variable not set.")
print("Please set it with: export OPENAI_API_KEY='your-key'")
exit(1)
# Define nested module structure
class NestedModule(dspy.Module):
"""A nested sub-module with its own predictor."""
def __init__(self):
self.inner = dspy.ChainOfThought("input -> output")
def forward(self, input):
return self.inner(input=input)
class TopLevelModule(dspy.Module):
"""Top-level module with multiple sub-modules."""
def __init__(self):
self.submodule1 = dspy.ChainOfThought("question -> answer")
self.submodule2 = dspy.ChainOfThought("answer -> score")
self.nested = NestedModule()
def forward(self, question):
# First sub-module call
answer = self.submodule1(question=question)
# Second sub-module call
score = self.submodule2(answer=answer.answer)
# Nested module call
nested_result = self.nested(input=answer.answer)
return dspy.Prediction(
answer=answer.answer,
score=score.score,
nested_output=nested_result.output
)
# Initialize the program and tracker
print("Initializing nested module program...")
program = TopLevelModule()
tracker = PerModuleUsageTracker(top_level_module=program)
# Configure DSPy with usage tracking
print("Configuring DSPy with usage tracking...")
dspy.configure(
lm=dspy.LM("openai/gpt-5-nano", cache=False),
track_usage=True,
callbacks=[tracker]
)
# Run the program
print("\nRunning program with question: 'What is the capital of France?'")
result = program(question="What is the capital of France?")
print(f"\nResult: {result.answer}")
print(f"Score: {result.score}")
# Get and display per-module usage as a table
print("\n" + "=" * 70)
print("Per-Module Token Usage Breakdown:")
print("=" * 70)
usage = tracker.get_module_usage()
if usage:
# Prepare table data
table_data = []
for module_path, module_usage in usage.items():
if isinstance(module_usage, dict):
for lm_name, tokens in module_usage.items():
if isinstance(tokens, dict):
table_data.append({
'module': module_path,
'input': tokens.get('prompt_tokens', 0),
'output': tokens.get('completion_tokens', 0),
'total': tokens.get('total_tokens', 0)
})
break # Only use first LM if multiple
# Print table
if table_data:
# Calculate column widths
max_module_len = max(len(str(row['module'])) for row in table_data)
max_input_len = max(len(str(row['input'])) for row in table_data)
max_output_len = max(len(str(row['output'])) for row in table_data)
max_total_len = max(len(str(row['total'])) for row in table_data)
# Header
header = f"{'Module':<{max_module_len}} | {'Input':>{max_input_len}} | {'Output':>{max_output_len}} | {'Total':>{max_total_len}}"
print(f"\n{header}")
print("-" * len(header))
# Rows
for row in table_data:
print(f"{row['module']:<{max_module_len}} | {row['input']:>{max_input_len}} | {row['output']:>{max_output_len}} | {row['total']:>{max_total_len}}")
else:
print("\nNo usage data collected. Make sure track_usage=True is set in dspy.configure()")
print("\n" + "=" * 70)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment