Last active
December 10, 2025 04:10
-
-
Save justanotheratom/20893a4c04eddb47612247ea4e343c25 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env python3 | |
| """ | |
| Per-Module Usage Tracker for DSPy | |
| Tracks token usage for individual sub-modules in multi-module DSPy programs. | |
| Quick Start: | |
| 1. pip install dspy | |
| 2. export OPENAI_API_KEY="your-key" | |
| 3. Run: python per_module_usage_tracker.py | |
| """ | |
| import warnings | |
| import logging | |
| # Suppress warnings before importing modules that may emit them | |
| warnings.filterwarnings("ignore", category=UserWarning) | |
| warnings.filterwarnings("ignore", message=".*Pydantic.*") | |
| import dspy | |
| from dspy.utils.callback import BaseCallback | |
| from collections import defaultdict | |
class PerModuleUsageTracker(BaseCallback):
    """Track token usage for each sub-module of a DSPy program separately.

    The tracker is a DSPy callback: when a tracked module starts it records a
    snapshot of the active usage tracker's totals, and when the module ends it
    stores the difference between the new totals and that snapshot under a
    descriptive dotted path (for example ``TopLevelModule.nested.inner``).

    Only modules that hold a ``Predict`` instance as a direct attribute are
    tracked. ``Predict`` instances themselves, the configured top-level module
    and pure container modules are skipped.
    """

    def __init__(self, top_level_module=None):
        """Create a tracker.

        Args:
            top_level_module: Optional root module of the program. When given,
                it is used as the first path component and is itself excluded
                from tracking.
        """
        self.module_usage = defaultdict(dict)  # module_path -> usage dict
        self.module_trackers = {}  # call_id -> (module_path, initial_usage_state)
        self.top_level_module = top_level_module  # Optional: root module for better naming

    def _find_module_attribute_name(self, module_instance, parent_module, _visited=None):
        """Find the attribute path under which ``module_instance`` hangs off ``parent_module``.

        Direct attributes are checked first. After that, attributes that are
        themselves ``dspy.Module`` objects are searched recursively, and
        list/tuple/dict attributes are checked for the instance as an item.

        Args:
            module_instance: The module to locate (compared by identity).
            parent_module: The module whose attributes are searched.
            _visited: Internal set of ``id()`` values of parents already
                searched; guards against modules that reference each other.

        Returns:
            A path such as ``"name"``, ``"outer.inner"``, ``"items[0]"`` or
            ``"table['key']"``, or ``None`` when the instance is not found.
        """
        if parent_module is None:
            return None

        visited = set() if _visited is None else _visited
        if id(parent_module) in visited:
            # Already searched on this lookup; a second visit could only loop.
            return None
        visited.add(id(parent_module))

        attributes = list(vars(parent_module).items())

        # Check direct attributes
        for attr_name, attr_value in attributes:
            if attr_value is module_instance:
                return attr_name

        # Check nested structures (lists, dicts, nested modules)
        # This handles cases like parent.submodule.inner
        for attr_name, attr_value in attributes:
            if isinstance(attr_value, dspy.Module):
                nested_name = self._find_module_attribute_name(
                    module_instance, attr_value, visited
                )
                if nested_name:
                    return f"{attr_name}.{nested_name}"
            elif isinstance(attr_value, (list, tuple)):
                for idx, item in enumerate(attr_value):
                    if item is module_instance:
                        return f"{attr_name}[{idx}]"
            elif isinstance(attr_value, dict):
                for key, item in attr_value.items():
                    if item is module_instance:
                        return f"{attr_name}['{key}']"
        return None

    def _has_direct_predictors(self, instance):
        """Return True if ``instance`` has a ``Predict`` as a direct attribute (not recursively)."""
        from dspy.predict.predict import Predict

        try:
            # Check only direct attributes, not recursive ones
            return any(
                isinstance(attr_value, Predict)
                for attr_value in list(instance.__dict__.values())
            )
        except Exception:
            # Best effort: an object we cannot inspect is treated as having none.
            return False

    def _build_module_path(self, instance):
        """Build a descriptive dotted path for a module instance.

        Returns:
            ``None`` when the instance must not be tracked (it is a
            ``Predict``, it is the configured top-level module, or it has no
            direct predictors). Otherwise a path made of the top module's
            class name followed by the attribute names leading to the
            instance; the bare class name when there are no relevant callers.
        """
        from dspy.predict.predict import Predict

        root = self.top_level_module

        # Skip tracking internal Predict instances (they're tracked via their parent modules)
        if isinstance(instance, Predict):
            return None
        # Skip the top-level module itself (we only track sub-modules)
        if root is not None and instance is root:
            return None
        # Skip container modules that don't have direct predictors
        # (they're just wrappers around other modules)
        if not self._has_direct_predictors(instance):
            return None

        # Keep only callers that are the top-level module or that own
        # predictors directly; Predict instances and pure containers are dropped.
        callers = []
        for module in dspy.settings.caller_modules or []:
            if isinstance(module, Predict):
                continue
            if (root is not None and module is root) or self._has_direct_predictors(module):
                callers.append(module)

        # No relevant callers: the instance is named by its class alone.
        if not callers:
            return instance.__class__.__name__

        top_module = root if root is not None else callers[0]
        path_parts = [top_module.__class__.__name__]

        # When the first caller is the top-level module its name is already
        # the first path component, so it must not be looked up again.
        skip_root = root is not None and callers[0] is root
        parent = top_module
        for module in (callers[1:] if skip_root else callers):
            attr_name = self._find_module_attribute_name(module, parent)
            if attr_name:
                path_parts.append(attr_name)
            parent = module

        # Add the current instance, located relative to its innermost caller.
        attr_name = self._find_module_attribute_name(instance, callers[-1])
        if attr_name:
            path_parts.append(attr_name)

        return ".".join(path_parts)

    def _get_usage_snapshot(self):
        """Return the active usage tracker's current totals, or ``{}`` if there is none."""
        usage_tracker = dspy.settings.usage_tracker
        if usage_tracker:
            return usage_tracker.get_total_tokens()
        return {}

    def on_module_start(self, call_id, instance, inputs):
        """Record the usage snapshot at the moment a tracked module starts."""
        module_path = self._build_module_path(instance)
        # Skip if path is None (e.g., internal Predict instances)
        if module_path is None:
            return
        # Store initial usage state for this module call
        self.module_trackers[call_id] = (module_path, self._get_usage_snapshot())

    def on_module_end(self, call_id, outputs, exception=None):
        """Attribute the usage accrued since ``on_module_start`` to the module's path.

        Calls whose ``call_id`` was never registered are ignored. Failures
        while computing usage are logged at DEBUG level instead of being
        raised, so that tracking never breaks the tracked program.
        """
        tracked = self.module_trackers.pop(call_id, None)
        if tracked is None:
            return

        module_path, initial_usage = tracked
        try:
            final_usage = self._get_usage_snapshot()
            # Ensure we have valid dict structures
            if not isinstance(initial_usage, dict):
                initial_usage = {}
            if not isinstance(final_usage, dict):
                final_usage = {}

            module_usage = self._calculate_usage_diff(initial_usage, final_usage)
            # Only store if we got valid usage data
            if not module_usage:
                return

            # Accumulate if this module path was called before
            if module_path in self.module_usage:
                module_usage = self._merge_usage(self.module_usage[module_path], module_usage)
            self.module_usage[module_path] = module_usage
        except Exception:
            # Deliberately best-effort, but leave a trace for debugging.
            logging.getLogger(__name__).debug(
                "Failed to record usage for module %r", module_path, exc_info=True
            )

    @staticmethod
    def _is_number(value):
        """Return True for int/float counters."""
        return isinstance(value, (int, float))

    def _diff_counters(self, initial, final):
        """Subtract ``initial`` from ``final`` for one LM's counter dict.

        Nested dicts (like ``prompt_tokens_details``) are diffed recursively
        with this same counter-level logic. A counter that is missing or
        ``None`` in ``initial`` counts as 0. Non-numeric leaves are skipped.
        """
        diff = {}
        for key, final_value in final.items():
            initial_value = initial.get(key)
            if isinstance(final_value, dict):
                diff[key] = self._diff_counters(
                    initial_value if isinstance(initial_value, dict) else {},
                    final_value,
                )
                continue
            if initial_value is None:
                initial_value = 0
            if self._is_number(final_value) and self._is_number(initial_value):
                diff[key] = final_value - initial_value
        return diff

    def _calculate_usage_diff(self, initial, final):
        """Calculate the difference between two usage dictionaries.

        Args:
            initial: Usage totals before the call, shaped ``{lm_name: counters}``.
            final: Usage totals after the call, same shape.

        Returns:
            ``{lm_name: counters}`` holding ``final - initial`` for every LM
            present in ``final``; LM entries that are not dicts are skipped.
            ``{}`` when ``final`` is not a dict.
        """
        if not isinstance(final, dict):
            return {}
        if not isinstance(initial, dict):
            initial = {}

        diff = {}
        for lm_name, final_usage in final.items():
            if not isinstance(final_usage, dict):
                continue
            initial_usage = initial.get(lm_name, {})
            if not isinstance(initial_usage, dict):
                initial_usage = {}
            diff[lm_name] = self._diff_counters(initial_usage, final_usage)
        return diff

    def _merge_counters(self, counters1, counters2):
        """Add two counter dicts key by key, recursing into nested dicts.

        Keys keep first-seen order (``counters1`` first). A missing or
        ``None`` counter counts as 0; other non-numeric leaves are skipped.
        """
        merged = {}
        for key in dict.fromkeys([*counters1, *counters2]):
            v1 = counters1.get(key)
            v2 = counters2.get(key)
            if isinstance(v1, dict) or isinstance(v2, dict):
                merged[key] = self._merge_counters(
                    v1 if isinstance(v1, dict) else {},
                    v2 if isinstance(v2, dict) else {},
                )
                continue
            v1 = 0 if v1 is None else v1
            v2 = 0 if v2 is None else v2
            if self._is_number(v1) and self._is_number(v2):
                merged[key] = v1 + v2
        return merged

    def _merge_usage(self, usage1, usage2):
        """Merge two ``{lm_name: counters}`` usage dictionaries by summing counters.

        Non-dict arguments and non-dict LM entries are treated as empty.
        LM names keep first-seen order so the result is deterministic.
        """
        if not isinstance(usage1, dict):
            usage1 = {}
        if not isinstance(usage2, dict):
            usage2 = {}

        merged = {}
        for lm_name in dict.fromkeys([*usage1, *usage2]):
            u1 = usage1.get(lm_name, {})
            u2 = usage2.get(lm_name, {})
            merged[lm_name] = self._merge_counters(
                u1 if isinstance(u1, dict) else {},
                u2 if isinstance(u2, dict) else {},
            )
        return merged

    def get_module_usage(self):
        """Return a shallow copy of the accumulated usage, keyed by module path."""
        return dict(self.module_usage)

    def get_module_usage_by_path(self, path_prefix):
        """Return the usage of every tracked module whose path starts with ``path_prefix``."""
        return {
            module_path: usage
            for module_path, usage in self.module_usage.items()
            if module_path.startswith(path_prefix)
        }
# Example usage with nested modules
if __name__ == "__main__":
    import os
    import sys

    # Suppress callback errors for cleaner output
    logging.getLogger("dspy.utils.callback").setLevel(logging.ERROR)

    # Check for API key
    if not os.getenv("OPENAI_API_KEY"):
        print("Error: OPENAI_API_KEY environment variable not set.")
        print("Please set it with: export OPENAI_API_KEY='your-key'")
        # sys.exit is always available; the bare exit() helper comes from the
        # site module and is not guaranteed to exist.
        sys.exit(1)

    # Define nested module structure
    class NestedModule(dspy.Module):
        """A nested sub-module with its own predictor."""

        def __init__(self):
            super().__init__()
            self.inner = dspy.ChainOfThought("input -> output")

        def forward(self, text):
            # The signature field is called "input"; the parameter is named
            # differently only to avoid shadowing the builtin.
            return self.inner(input=text)

    class TopLevelModule(dspy.Module):
        """Top-level module with multiple sub-modules."""

        def __init__(self):
            super().__init__()
            self.submodule1 = dspy.ChainOfThought("question -> answer")
            self.submodule2 = dspy.ChainOfThought("answer -> score")
            self.nested = NestedModule()

        def forward(self, question):
            # First sub-module call
            answer = self.submodule1(question=question)
            # Second sub-module call
            score = self.submodule2(answer=answer.answer)
            # Nested module call
            nested_result = self.nested(text=answer.answer)
            return dspy.Prediction(
                answer=answer.answer,
                score=score.score,
                nested_output=nested_result.output,
            )

    # Initialize the program and tracker
    print("Initializing nested module program...")
    program = TopLevelModule()
    tracker = PerModuleUsageTracker(top_level_module=program)

    # Configure DSPy with usage tracking
    print("Configuring DSPy with usage tracking...")
    dspy.configure(
        lm=dspy.LM("openai/gpt-5-nano", cache=False),
        track_usage=True,
        callbacks=[tracker],
    )

    # Run the program
    print("\nRunning program with question: 'What is the capital of France?'")
    result = program(question="What is the capital of France?")
    print(f"\nResult: {result.answer}")
    print(f"Score: {result.score}")

    # Get and display per-module usage as a table
    print("\n" + "=" * 70)
    print("Per-Module Token Usage Breakdown:")
    print("=" * 70)

    # Prepare table data: one row per tracked module
    table_data = []
    for module_path, module_usage in tracker.get_module_usage().items():
        if not isinstance(module_usage, dict):
            continue
        for tokens in module_usage.values():
            if isinstance(tokens, dict):
                table_data.append({
                    'module': module_path,
                    'input': tokens.get('prompt_tokens', 0),
                    'output': tokens.get('completion_tokens', 0),
                    'total': tokens.get('total_tokens', 0),
                })
                break  # Only use first LM if multiple

    if table_data:
        labels = {'module': 'Module', 'input': 'Input', 'output': 'Output', 'total': 'Total'}
        # A column must fit its header label as well as its widest cell,
        # otherwise the header and the rows do not line up.
        widths = {
            column: max(len(label), *(len(str(row[column])) for row in table_data))
            for column, label in labels.items()
        }

        # Header
        header = (
            f"{labels['module']:<{widths['module']}} | "
            f"{labels['input']:>{widths['input']}} | "
            f"{labels['output']:>{widths['output']}} | "
            f"{labels['total']:>{widths['total']}}"
        )
        print(f"\n{header}")
        print("-" * len(header))

        # Rows (cells go through str() so a None count cannot break formatting)
        for row in table_data:
            print(
                f"{str(row['module']):<{widths['module']}} | "
                f"{str(row['input']):>{widths['input']}} | "
                f"{str(row['output']):>{widths['output']}} | "
                f"{str(row['total']):>{widths['total']}}"
            )
    else:
        print("\nNo usage data collected. Make sure track_usage=True is set in dspy.configure()")

    print("\n" + "=" * 70)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment