Last active
December 3, 2025 11:11
-
-
Save justanotheratom/372c8248a17ad97ca18b2dcc0ac0ceef to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env python3 | |
| """ | |
| DSPy Streaming Demo - Self-Contained Script | |
| This script demonstrates DSPy's streaming capabilities with tool calls and status messages. | |
| It starts a server in the background, runs a demo query, and cleans up on exit. | |
| Setup: | |
| export OPENAI_API_KEY="sk-your-key" | |
| python3 -m venv .venv | |
| source .venv/bin/activate | |
| python3 -m pip install fastapi uvicorn httpx dspy | |
| Run: | |
| python3 dspy_status_streaming.py | |
| OR | |
| python3 dspy_status_streaming.py "Your custom question here" | |
| """ | |
| import asyncio | |
| import json | |
| import multiprocessing | |
| import os | |
| import sys | |
| import time | |
| from datetime import datetime | |
| from typing import Any | |
| # ============================================================================= | |
| # Ignore warnings | |
| # ============================================================================= | |
| import warnings | |
| warnings.filterwarnings("ignore", category=UserWarning, module="pydantic") | |
| warnings.filterwarnings("ignore", category=UserWarning, module="openai") | |
| # ============================================================================= | |
| # Dependency Check - Run before multiprocessing to get clear errors | |
| # ============================================================================= | |
def check_dependencies(required=("dspy", "fastapi", "uvicorn", "httpx")):
    """Verify that every required package can be imported.

    Each name in ``required`` is imported in turn. Names that raise
    ``ImportError`` are collected and reported together, so a single run
    lists everything that is missing along with a ready-to-paste
    ``pip install`` command.

    Args:
        required: Importable module names to check. Defaults to the four
            third-party packages this script uses.

    Raises:
        SystemExit: With exit code 1 when at least one module is missing.
    """
    import importlib

    missing = []
    for module_name in required:
        try:
            importlib.import_module(module_name)
        except ImportError:
            missing.append(module_name)

    if not missing:
        return

    # Raw ANSI codes are spelled out here because this function is called at
    # import time, before the Colors class further down the file is defined.
    print(f"\033[91m\033[1mERROR: Missing required dependencies:\033[0m {', '.join(missing)}")
    print("\033[90mInstall them with:\033[0m")
    print(f" pip install {' '.join(missing)}")
    sys.exit(1)
| # Check dependencies immediately on import | |
| check_dependencies() | |
| # ============================================================================= | |
| # ANSI Colors for Pretty Output | |
| # ============================================================================= | |
class Colors:
    """ANSI escape sequences used to colorize terminal output.

    Attributes are plain class-level strings so they can be interpolated
    directly into f-strings.
    """

    # Text attributes
    BOLD = "\033[1m"
    DIM = "\033[2m"
    RESET = "\033[0m"

    # Foreground colors, named after the role each one plays in the output
    TIMESTAMP = "\033[90m"  # Gray
    ERROR = "\033[91m"  # Red
    SUCCESS = "\033[92m"  # Green
    STATUS = "\033[94m"  # Blue
    THOUGHT = "\033[95m"  # Magenta
    ANSWER = "\033[96m"  # Cyan
| # ============================================================================= | |
| # SERVER CODE | |
| # ============================================================================= | |
def run_server(port: int = 8000):
    """Build the DSPy ReAct agent and serve it over FastAPI in this process.

    The function configures DSPy with an OpenAI model, defines five simulated
    tools, wraps a ReAct agent with ``dspy.streamify`` and exposes it through
    two routes:

    * ``GET /health`` returns ``{"status": "ready"}``.
    * ``POST /v1/research`` streams server-sent events (``data: <json>``
      lines) of type ``status``, ``token``, ``prediction`` or ``error``,
      followed by ``data: [DONE]`` on success.

    The third-party imports live inside the function body; ``main`` starts
    this function as the target of a separate process.

    Args:
        port: TCP port on 127.0.0.1 that uvicorn binds to.

    Raises:
        SystemExit: With exit code 1 when ``OPENAI_API_KEY`` is not set.
    """
    import dspy
    from fastapi import FastAPI
    from fastapi.responses import StreamingResponse
    from pydantic import BaseModel
    import uvicorn

    # Configure DSPy
    OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
    if not OPENAI_API_KEY:
        print(f"{Colors.ERROR}ERROR: Set OPENAI_API_KEY environment variable{Colors.RESET}")
        sys.exit(1)
    dspy.configure(
        lm=dspy.LM("openai/gpt-4o-mini", api_key=OPENAI_API_KEY, cache=False)
    )

    # -------------------------------------------------------------------------
    # Tools - Simulated external services
    # -------------------------------------------------------------------------
    def search_web(query: str) -> str:
        """Search the web for information on a topic."""
        results = {
            "python": "Python is a high-level programming language created by Guido van Rossum in 1991. It emphasizes code readability and supports multiple programming paradigms.",
            "machine learning": "Machine learning is a subset of AI that enables systems to learn from data. Popular frameworks include TensorFlow, PyTorch, and scikit-learn.",
            "climate": "Climate change refers to long-term shifts in global temperatures. The Paris Agreement aims to limit warming to 1.5°C above pre-industrial levels.",
            "space": "SpaceX and NASA are leading space exploration. Recent missions include Mars rovers and the James Webb Space Telescope.",
            "dspy": "DSPy is a framework for programming language models. It provides modules like Predict, ChainOfThought, and ReAct for building LM-powered applications.",
        }
        # Lowercase once; the first topic keyword found in the query wins.
        query_lower = query.lower()
        for key, value in results.items():
            if key in query_lower:
                return value
        return f"Search results for '{query}': This is an active area of research with ongoing developments."

    def get_weather(location: str) -> str:
        """Get current weather for a location."""
        weather_data = {
            "new york": {"temp": 72, "condition": "Partly cloudy", "humidity": 65},
            "london": {"temp": 58, "condition": "Rainy", "humidity": 85},
            "tokyo": {"temp": 80, "condition": "Sunny", "humidity": 70},
            "san francisco": {"temp": 65, "condition": "Foggy", "humidity": 78},
        }
        loc_lower = location.lower()
        for city, data in weather_data.items():
            if city in loc_lower:
                return f"Weather in {location}: {data['temp']}°F, {data['condition']}, {data['humidity']}% humidity"
        # Unknown locations get a fixed placeholder reading.
        return f"Weather in {location}: 70°F, Clear skies, 60% humidity"

    def calculate(expression: str) -> str:
        """Evaluate a mathematical expression."""
        try:
            allowed_chars = set("0123456789+-*/.() ")
            if all(c in allowed_chars for c in expression):
                # NOTE(review): this is eval() on model-generated text. The
                # character whitelist rules out names and attribute access,
                # and builtins are stripped as a second layer of defense, but
                # an input such as "9**9**9**9" can still consume unbounded
                # CPU and memory. Replace with a bounded arithmetic evaluator
                # before exposing this beyond a local demo.
                result = eval(expression, {"__builtins__": {}}, {})
                return f"Result: {expression} = {result}"
            return "Error: Invalid expression. Only basic arithmetic supported."
        except Exception as e:
            return f"Calculation error: {str(e)}"

    def get_stock_price(symbol: str) -> str:
        """Get current stock price for a ticker symbol."""
        stocks = {
            "AAPL": {"price": 178.50, "change": +2.35, "pct": "+1.34%"},
            "GOOGL": {"price": 141.25, "change": -0.75, "pct": "-0.53%"},
            "MSFT": {"price": 378.90, "change": +4.20, "pct": "+1.12%"},
            "TSLA": {"price": 245.60, "change": -8.40, "pct": "-3.31%"},
        }
        symbol_upper = symbol.upper().strip()
        if symbol_upper in stocks:
            data = stocks[symbol_upper]
            return f"{symbol_upper}: ${data['price']:.2f} ({data['pct']})"
        return f"{symbol_upper}: $100.00 (+0.50%) - Simulated data"

    def get_current_time(timezone: str = "UTC") -> str:
        """Get the current date and time."""
        from datetime import timezone as dt_timezone

        # The reported time must match the label that is printed with it, so
        # the requested zone is resolved first; anything that cannot be
        # resolved falls back to UTC and is labelled as UTC.
        zone_name = (timezone or "UTC").strip()
        label = zone_name
        if zone_name.upper() == "UTC":
            tzinfo = dt_timezone.utc
        else:
            try:
                from zoneinfo import ZoneInfo

                tzinfo = ZoneInfo(zone_name)
            except (ImportError, KeyError, ValueError, OSError):
                tzinfo = dt_timezone.utc
                label = "UTC"
        now = datetime.now(tzinfo)
        return f"Current time ({label}): {now.strftime('%Y-%m-%d %H:%M:%S')}"

    # -------------------------------------------------------------------------
    # Status Message Provider
    # -------------------------------------------------------------------------
    class ResearchStatusProvider(dspy.streaming.StatusMessageProvider):
        """Provides descriptive status messages for each stage of execution."""

        def tool_start_status_message(self, instance: Any, inputs: dict) -> str:
            tool_name = getattr(instance, 'name', 'unknown')
            messages = {
                "search_web": f"🔍 Searching the web for: {inputs.get('query', 'information')}...",
                "get_weather": f"🌤️ Fetching weather for: {inputs.get('location', 'location')}...",
                "calculate": f"🧮 Computing: {inputs.get('expression', 'expression')}...",
                "get_stock_price": f"📈 Looking up stock: {inputs.get('symbol', 'symbol')}...",
                "get_current_time": "🕐 Getting current time...",
            }
            return messages.get(tool_name, f"⚙️ Calling tool: {tool_name}...")

        def tool_end_status_message(self, outputs: Any) -> str:
            # Keep status lines short: show at most the first 80 characters.
            text = str(outputs)
            output_str = text[:80] + "..." if len(text) > 80 else text
            return f"✅ Result: {output_str}"

        def lm_start_status_message(self, instance: Any, inputs: dict) -> str:
            return "🤖 Thinking..."

        def lm_end_status_message(self, outputs: Any) -> str:
            return "💭 Generated response"

    # -------------------------------------------------------------------------
    # DSPy Program Setup
    # -------------------------------------------------------------------------
    tools = [
        dspy.Tool(search_web, name="search_web", desc="Search the web for information"),
        dspy.Tool(get_weather, name="get_weather", desc="Get weather for a location"),
        dspy.Tool(calculate, name="calculate", desc="Evaluate math like '2 + 2'"),
        dspy.Tool(get_stock_price, name="get_stock_price", desc="Get stock price for AAPL, GOOGL, etc."),
        dspy.Tool(get_current_time, name="get_current_time", desc="Get current date/time"),
    ]
    react_agent = dspy.ReAct(
        signature="question -> answer",
        tools=tools,
        max_iters=5,
    )
    stream_listeners = [
        dspy.streaming.StreamListener(signature_field_name="next_thought", allow_reuse=True),
        dspy.streaming.StreamListener(signature_field_name="answer"),
    ]
    streaming_agent = dspy.streamify(
        react_agent,
        stream_listeners=stream_listeners,
        status_message_provider=ResearchStatusProvider(),
    )

    # -------------------------------------------------------------------------
    # FastAPI App
    # -------------------------------------------------------------------------
    app = FastAPI(title="DSPy Research Assistant")

    class Query(BaseModel):
        question: str

    @app.get("/health")
    async def health():
        return {"status": "ready"}

    @app.post("/v1/research")
    async def research_stream(query: Query):
        async def event_generator():
            try:
                output_stream = streaming_agent(question=query.question)
                async for item in output_stream:
                    if isinstance(item, dspy.streaming.StatusMessage):
                        payload = {"type": "status", "message": item.message}
                        yield f"data: {json.dumps(payload)}\n\n"
                    elif isinstance(item, dspy.streaming.StreamResponse):
                        payload = {
                            "type": "token",
                            "field": item.signature_field_name,
                            "predictor": item.predict_name,
                            "chunk": item.chunk,
                        }
                        yield f"data: {json.dumps(payload)}\n\n"
                    elif isinstance(item, dspy.Prediction):
                        # Drop private fields. Remaining values are not
                        # guaranteed to be JSON types, so anything json cannot
                        # encode is stringified instead of turning the final
                        # prediction into an error event.
                        data = {k: v for k, v in item.items() if not k.startswith("_")}
                        payload = {"type": "prediction", "data": data}
                        yield f"data: {json.dumps(payload, default=str)}\n\n"
                yield "data: [DONE]\n\n"
            except Exception as e:
                # Report the failure in-band so the client can display it.
                yield f"data: {json.dumps({'type': 'error', 'message': str(e)})}\n\n"

        return StreamingResponse(event_generator(), media_type="text/event-stream")

    # Run server (suppress logs for cleaner output)
    uvicorn.run(app, host="127.0.0.1", port=port, log_level="warning")
| # ============================================================================= | |
| # CLIENT CODE | |
| # ============================================================================= | |
def print_status(message: str):
    """Print ``message`` on its own line in the status color."""
    print("{}{}{}".format(Colors.STATUS, message, Colors.RESET))
def print_field_header(field: str):
    """Print a bold, upper-cased heading for ``field`` above a dim rule."""
    title = field.upper()
    rule = "─" * 50
    print(f"\n{Colors.BOLD}📝 [{title}]{Colors.RESET}")
    print(f"{Colors.DIM}{rule}{Colors.RESET}")
def print_prediction(data: dict):
    """Print the final prediction as a banner followed by its fields.

    The ``trajectory`` entry and any key starting with an underscore are
    left out; every other key is printed with its value on the next line.
    """
    banner = "=" * 60
    print(f"\n\n{Colors.BOLD}{banner}{Colors.RESET}")
    print(f"{Colors.SUCCESS}{Colors.BOLD}✨ FINAL ANSWER{Colors.RESET}")
    print(f"{Colors.BOLD}{banner}{Colors.RESET}\n")

    for name, content in data.items():
        is_visible = name != "trajectory" and not name.startswith("_")
        if is_visible:
            print(f"{Colors.BOLD}{name}:{Colors.RESET}")
            print(f" {content}\n")
async def run_client(question: str, server_url: str = "http://127.0.0.1:8000"):
    """Connect to server and stream the response.

    Posts ``question`` to ``{server_url}/v1/research`` and renders each
    ``data: ...`` event line as it arrives: status messages, streamed tokens
    grouped under a header per field, the final prediction, and error
    events. Reading stops at the ``[DONE]`` sentinel.

    Malformed events (invalid JSON, JSON that is not an object, or an object
    missing an expected key) are skipped so that one bad line does not end
    the stream.

    Args:
        question: The question sent in the request body.
        server_url: Base URL of the server, without a trailing path.
    """
    import httpx

    url = f"{server_url}/v1/research"
    current_field = None
    token_count = 0
    status_count = 0
    try:
        async with httpx.AsyncClient(timeout=120.0) as client:
            async with client.stream(
                "POST", url,
                json={"question": question},
                headers={"Accept": "text/event-stream"},
            ) as response:
                if response.status_code != 200:
                    print(f"{Colors.ERROR}Server error: {response.status_code}{Colors.RESET}")
                    return
                async for line in response.aiter_lines():
                    if not line.startswith("data: "):
                        continue
                    data_str = line[len("data: "):]
                    if data_str.strip() == "[DONE]":
                        print(f"\n\n{Colors.SUCCESS}✅ Stream complete!{Colors.RESET}")
                        print(f"{Colors.DIM} {status_count} status messages, {token_count} tokens{Colors.RESET}")
                        break
                    try:
                        payload = json.loads(data_str)
                    except json.JSONDecodeError:
                        continue
                    # Valid JSON that is not an object (a bare number, string
                    # or list) cannot carry a "type"; skip it like any other
                    # malformed event instead of letting the lookup raise
                    # TypeError and abort the stream.
                    if not isinstance(payload, dict):
                        continue
                    try:
                        event_type = payload["type"]
                        if event_type == "status":
                            print_status(payload["message"])
                            status_count += 1
                        elif event_type == "token":
                            field = payload["field"]
                            chunk = payload["chunk"]
                            # Start a new section whenever the streamed field changes.
                            if field != current_field:
                                if current_field:
                                    print()
                                print_field_header(field)
                                current_field = field
                            color = Colors.THOUGHT if "thought" in field.lower() else Colors.ANSWER
                            print(f"{color}{chunk}{Colors.RESET}", end="", flush=True)
                            token_count += 1
                        elif event_type == "prediction":
                            print_prediction(payload["data"])
                        elif event_type == "error":
                            print(f"{Colors.ERROR}Error: {payload['message']}{Colors.RESET}")
                    except KeyError:
                        continue
    except httpx.ConnectError:
        print(f"{Colors.ERROR}Could not connect to server{Colors.RESET}")
    except Exception as e:
        print(f"{Colors.ERROR}Error: {e}{Colors.RESET}")
async def wait_for_server(url: str, timeout: float = 30.0):
    """Wait for server to become ready.

    Polls ``{url}/health`` until it answers with HTTP 200 or ``timeout``
    seconds have elapsed, pausing half a second between attempts.

    Args:
        url: Base URL of the server, without a trailing path.
        timeout: Maximum number of seconds to keep polling.

    Returns:
        True as soon as a health check returns status 200, False if the
        deadline passes first.
    """
    import httpx

    # A monotonic clock keeps the deadline immune to wall-clock adjustments.
    deadline = time.monotonic() + timeout
    # One client is reused for every attempt instead of building one per poll.
    async with httpx.AsyncClient() as client:
        while time.monotonic() < deadline:
            try:
                resp = await client.get(f"{url}/health", timeout=2.0)
            except httpx.HTTPError:
                # Connection refused / timed out: the server is not up yet.
                # Only HTTP-level failures are swallowed so that task
                # cancellation and Ctrl-C still propagate to the caller.
                pass
            else:
                if resp.status_code == 200:
                    return True
            await asyncio.sleep(0.5)
    return False
| # ============================================================================= | |
| # MAIN ENTRY POINT | |
| # ============================================================================= | |
def main():
    """Run the demo end to end: start the server, stream one query, clean up.

    Exits with status 1 when ``OPENAI_API_KEY`` is unset or the server does
    not report ready in time. The question comes from the command-line
    arguments joined with spaces, or a built-in default when none are given.
    """
    # Check for API key
    api_key = os.environ.get("OPENAI_API_KEY")
    if not api_key:
        print(f"{Colors.ERROR}{Colors.BOLD}ERROR:{Colors.RESET} Set OPENAI_API_KEY environment variable")
        print(f"{Colors.DIM} export OPENAI_API_KEY='sk-your-key-here'{Colors.RESET}")
        sys.exit(1)

    # Disable colors if not a TTY
    if not sys.stdout.isatty():
        public_names = [name for name in dir(Colors) if not name.startswith("_")]
        for name in public_names:
            setattr(Colors, name, "")

    # Get question from args or use default
    cli_words = sys.argv[1:]
    question = (
        " ".join(cli_words)
        if cli_words
        else "What's the weather in Tokyo and what's AAPL stock price? Also tell me the current time."
    )

    port = 8000
    server_url = f"http://127.0.0.1:{port}"

    # Print header
    print(f"\n{Colors.BOLD}{'='*60}{Colors.RESET}")
    print(f"{Colors.BOLD} 🔬 DSPy Streaming Demo - Research Assistant{Colors.RESET}")
    print(f"{Colors.BOLD}{'='*60}{Colors.RESET}\n")

    # Start server process
    # Use 'spawn' context for macOS compatibility
    print(f"{Colors.DIM}Starting server...{Colors.RESET}")
    spawn_ctx = multiprocessing.get_context('spawn')
    server_process = spawn_ctx.Process(target=run_server, args=(port,), daemon=True)
    server_process.start()

    try:
        # Wait for server to be ready
        print(f"{Colors.DIM}Waiting for server to be ready...{Colors.RESET}")
        if not asyncio.run(wait_for_server(server_url)):
            print(f"{Colors.ERROR}Server failed to start within timeout{Colors.RESET}")
            sys.exit(1)

        print(f"{Colors.SUCCESS}Server ready!{Colors.RESET}\n")
        print(f"{Colors.BOLD}Question:{Colors.RESET} {question}")
        print(f"{Colors.DIM}{'─'*60}{Colors.RESET}")
        print(f"{Colors.BOLD}📡 Streaming Response:{Colors.RESET}\n")

        # Run the client
        asyncio.run(run_client(question, server_url))
    except KeyboardInterrupt:
        print(f"\n{Colors.DIM}Interrupted by user{Colors.RESET}")
    finally:
        # Clean up server process: ask politely first, force only if needed
        print(f"\n{Colors.DIM}Shutting down server...{Colors.RESET}")
        if server_process.is_alive():
            server_process.terminate()
            server_process.join(timeout=5)
            if server_process.is_alive():
                server_process.kill()
        print(f"{Colors.SUCCESS}Done!{Colors.RESET}\n")
| if __name__ == "__main__": | |
| main() |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment