Created
January 16, 2026 21:57
-
-
Save rsarv3006/e5cbba4e00799fd399adb940fb2ff2e1 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env python3 | |
| """ | |
| Distributed Zig function scoring with Ollama endpoints. | |
| Scores Zig functions for quality, complexity, and production-readiness. | |
| """ | |
| import requests | |
| from typing import Dict, List, Any | |
| import time | |
| import queue | |
| import threading | |
| import sys | |
| import json | |
| import os | |
# Configuration
# Input corpus of extracted Zig functions and where scored output goes.
INPUT_FILE = "extracted_functions.txt"
OUTPUT_FILE = "scored_zig_functions.json"
# Intermediate save file; deleted after a successful final write.
CHECKPOINT_FILE = ".scored_functions_checkpoint.json"
# Ollama endpoints - add as many as you want!
# Each entry: display name, /api/generate URL, and how many scorer
# threads to run against that endpoint.
OLLAMA_ENDPOINTS = [
    # {"name": "P40-Local", "url": "http://localhost:11434/api/generate", "workers": 2},
    {"name": "MacBook-22", "url": "http://10.0.0.22:11434/api/generate", "workers": 2},
    {"name": "MacBook-31", "url": "http://10.0.0.31:11434/api/generate", "workers": 2},
]
# Fallback worker count for endpoints that omit the "workers" key.
WORKERS_PER_ENDPOINT = 2
# Ollama model to use
# First model from this list found on an endpoint is selected for it.
OLLAMA_MODEL_PREFERENCES = [
    "qwen2.5-coder:7b-instruct",
]
# NOTE(review): CHECKPOINT_INTERVAL is not referenced in the visible code
# (checkpointing is time-based, every 5 minutes) — confirm before relying on it.
CHECKPOINT_INTERVAL = 200
# Upper bound on queued-but-unscored items (backpressure on the producer).
WORK_QUEUE_SIZE = 500
# Prompt sent per function; {code} is filled in, doubled braces are the
# literal JSON schema the model must emit.
SCORING_PROMPT = """Rate this Zig function (0-100 points).
COMPLETENESS (0-25): Full implementation?
COMPLEXITY (0-20): Non-trivial logic?
REAL_WORLD (0-20): Production-worthy? (Error handling with !, allocator patterns, etc.)
CODE_QUALITY (0-20): Idiomatic Zig? (Proper error unions, memory management, etc.)
EDUCATIONAL (0-15): Teaches Zig concepts? (Comptime, optionals, slices, etc.)
Function:
```zig
{code}
```
JSON output only:
{{
"completeness": <0-25>,
"complexity": <0-20>,
"realWorld": <0-20>,
"quality": <0-20>,
"educational": <0-15>,
"overallScore": <sum>
}}
"""
def parse_zig_functions(file_path: str) -> List[Dict[str, Any]]:
    """Parse extracted_functions.txt into a list of function records.

    Records are delimited by lines beginning with '===FUNCTION==='.
    Each record has 'Name:'/'Repo:'/'File:' header lines followed by a
    'Code:' marker; every subsequent line up to the next delimiter is
    part of the function body. Records without both metadata and code
    are silently dropped.
    """
    functions: List[Dict[str, Any]] = []
    meta: Dict[str, Any] = {}
    body: List[str] = []
    collecting = False

    def flush() -> None:
        # Emit the record being accumulated, if it is complete.
        if meta and body:
            meta['code'] = '\n'.join(body)
            functions.append(meta)

    with open(file_path, 'r', encoding='utf-8', errors='ignore') as fh:
        for raw in fh:
            text = raw.rstrip('\n')
            if text.startswith('===FUNCTION==='):
                flush()
                meta = {}
                body = []
                collecting = False
            elif text.startswith('Name:'):
                meta['function_name'] = text.split(':', 1)[1].strip()
            elif text.startswith('Repo:'):
                meta['repository'] = text.split(':', 1)[1].strip()
            elif text.startswith('File:'):
                meta['file'] = text.split(':', 1)[1].strip()
            elif text.startswith('Code:'):
                collecting = True
            elif collecting:
                body.append(text)
    flush()  # final record has no trailing delimiter
    return functions
def regex_prefilter(functions: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Fast regex-based filtering for Zig functions.

    Drops test/benchmark functions and functions with an empty body;
    everything else (including very short functions) is kept so the
    model can judge it. Prints a summary of what was rejected.
    """
    print("Running regex pre-filter (no GPU)...")
    kept: List[Dict[str, Any]] = []
    n_test = 0
    n_empty = 0
    for func in functions:
        source = func.get('code', '')
        lowered = func.get('function_name', '').lower()
        # Reject test functions
        is_test = (lowered.startswith('test')
                   or lowered.startswith('benchmark')
                   or 'test.zig' in func.get('file', ''))
        if is_test:
            n_test += 1
            continue
        # Reject empty functions: strip blanks and // comments, then drop
        # signature lines and bare braces — anything left is real body.
        stripped = [ln.strip() for ln in source.split('\n')
                    if ln.strip() and not ln.strip().startswith('//')]
        body = [ln for ln in stripped
                if not ln.startswith('fn')
                and not ln.startswith('pub fn')
                and ln not in ('{', '}')]
        if not body:
            n_empty += 1
            continue
        # Removed trivial filter - let the model score short functions
        # Some elegant algorithms are concise (like LeetCode solutions)
        kept.append(func)
    print(f"✓ Regex pre-filter complete:")
    print(f"  Input: {len(functions)}")
    print(f"  Rejected (test): {n_test}")
    print(f"  Rejected (empty): {n_empty}")
    print(f"  Remaining: {len(kept)} (let model score short functions)")
    print()
    return kept
def check_ollama_endpoint(endpoint: Dict[str, str]) -> bool:
    """Probe an Ollama endpoint and select a model for it.

    Queries /api/tags (derived from the endpoint's /api/generate URL)
    and picks the first model from OLLAMA_MODEL_PREFERENCES that the
    endpoint has installed, storing it in endpoint["model"]. Returns
    True on success; False when unreachable, non-200, or when no
    preferred model is present.
    """
    try:
        tags_url = endpoint["url"].replace("/api/generate", "/api/tags")
        print(f" [{endpoint['name']}] Checking {tags_url}...", end=" ", flush=True)
        response = requests.get(tags_url, timeout=5)
        if response.status_code != 200:
            print(f"\n⚠️ HTTP {response.status_code}")
            return False
        available = [entry["name"] for entry in response.json().get("models", [])]
        chosen = next((m for m in OLLAMA_MODEL_PREFERENCES if m in available), None)
        if chosen is not None:
            endpoint["model"] = chosen  # side effect: remember the pick
            print(f"✓ (using {chosen})")
            return True
        print("\n⚠️ None of the preferred models found")
        print(f" Available: {', '.join(available[:5])}")
        return False
    except Exception as e:
        # Network/parse failures are reported but never raised.
        print(f"\n⚠️ Connection failed: {str(e)[:100]}")
        return False
def _extract_score_json(text: str) -> Dict[str, Any]:
    """Extract and parse the first JSON object from a model response.

    Strips a surrounding ``` fence if present, then takes everything from
    the first '{' to the last '}'. Raises ValueError when no JSON object
    is found; json.loads may raise JSONDecodeError for malformed JSON.
    """
    text = text.strip()
    if text.startswith("```"):
        # Drop the opening and closing fence lines (language tag included).
        lines = text.split("\n")
        text = "\n".join(lines[1:-1]) if len(lines) > 2 else text
    start_idx = text.find("{")
    end_idx = text.rfind("}") + 1
    if start_idx != -1 and end_idx > start_idx:
        return json.loads(text[start_idx:end_idx])
    raise ValueError("No JSON found in Ollama response")


def call_ollama(code: str, endpoint_url: str, model: str) -> Dict[str, Any]:
    """Score a Zig function via an Ollama /api/generate endpoint.

    The code is truncated to 1000 characters to bound prompt size, then
    scored with SCORING_PROMPT. Returns the parsed score dict. Raises
    requests.HTTPError on non-2xx responses, ValueError/JSONDecodeError
    when the model's reply contains no parseable JSON.
    """
    if len(code) > 1000:
        code = code[:1000] + "\n... (truncated)"
    prompt = SCORING_PROMPT.format(code=code)
    response = requests.post(
        endpoint_url,
        json={
            "model": model,
            "prompt": prompt,
            "stream": False,
            "system": "You are a Zig code quality evaluator. Score each criterion independently and respond with valid JSON only.",
            "options": {
                # NOTE(review): 0.7 is fairly high for a structured scoring
                # task; scores will vary run to run — confirm this is intended.
                "temperature": 0.7,
                "num_predict": 100,
            }
        },
        timeout=300
    )
    response.raise_for_status()
    result = response.json()
    return _extract_score_json(result.get("response", ""))
def producer_worker(functions: List[Dict[str, Any]], work_queue: queue.Queue, stats: Dict, num_workers: int):
    """Feed every function into the work queue, then signal shutdown.

    After all work is queued, one None sentinel per consumer is pushed
    so each scorer thread knows when to exit. Records the total queued
    count in stats['queued'].
    """
    print("[PRODUCER] Starting...")
    total = len(functions)
    for idx, item in enumerate(functions, start=1):
        work_queue.put(item)  # blocks when the queue is full (backpressure)
        if idx % 1000 == 0:
            print(f"[PRODUCER] Queued {idx}/{total}")
    # Signal all workers to stop
    for _ in range(num_workers):
        work_queue.put(None)
    stats['queued'] = total
    print(f"[PRODUCER] Complete! Queued {total} functions")
def ollama_scorer_worker(endpoint: Dict[str, str], work_queue: queue.Queue, results: List,
                         results_lock: threading.Lock, stats: Dict, stats_lock: threading.Lock):
    """Ollama scorer: Pull from queue and score with Ollama endpoint.

    Consumes items until a None sentinel arrives. Each successfully
    scored function is appended to the shared `results` list (under
    `results_lock`); per-worker counts go into `stats` (under
    `stats_lock`). Failed items are re-queued with a retry counter.

    Fix: retries are now capped. Previously a permanently failing item
    was re-queued forever and, once the sentinels passed it, was left
    in the queue unscored with no diagnostic.
    """
    name = endpoint["name"]
    url = endpoint["url"]
    model = endpoint.get("model", OLLAMA_MODEL_PREFERENCES[0])
    max_retries = 10  # after this many failures the item is dropped with a message
    print(f"[{name}] Scorer starting...")
    scored_count = 0
    worker_start = time.time()
    while True:
        func = work_queue.get()
        if func is None:
            # Sentinel from the producer: this worker is done.
            work_queue.task_done()
            break
        retry_count = func.get('_retry_count', 0)
        try:
            score_data = call_ollama(func['code'], url, model)
            scored_func = func.copy()
            scored_func.pop('_retry_count', None)  # internal bookkeeping, not output
            scored_func["overallScore"] = int(score_data["overallScore"])
            scored_func["scoreBreakdown"] = {
                "completeness": score_data.get("completeness", 0),
                "complexity": score_data.get("complexity", 0),
                "realWorld": score_data.get("realWorld", 0),
                "quality": score_data.get("quality", 0),
                "educational": score_data.get("educational", 0)
            }
            scored_func["scorer"] = name
            with results_lock:
                results.append(scored_func)
            scored_count += 1
            with stats_lock:
                stats[f'{name}_scored'] = scored_count
            if scored_count % 100 == 0:
                worker_elapsed = time.time() - worker_start
                rate = scored_count / worker_elapsed
                total_scored = len(results)
                print(f"[{name}] {scored_count} scored | {rate:.2f}/s | Total: {total_scored:,}")
        except Exception as e:
            if retry_count + 1 >= max_retries:
                # Give up on this item instead of re-queuing it forever.
                print(f"[{name}] Dropping {func.get('function_name', '?')} "
                      f"after {max_retries} failed attempts: {str(e)[:100]}")
            else:
                func['_retry_count'] = retry_count + 1
                work_queue.put(func)
                # Log only occasionally to avoid flooding the console.
                if retry_count > 0 and retry_count % 5 == 0:
                    print(f"[{name}] Error scoring {func['function_name']}, "
                          f"retry {retry_count + 1}: {str(e)[:100]}")
        work_queue.task_done()
    with stats_lock:
        stats[f'{name}_scored'] = scored_count
    print(f"[{name}] Complete! Scored {scored_count} functions")
def checkpoint_worker(results: List, results_lock: threading.Lock, stop_event: threading.Event):
    """Periodic checkpoint saver.

    Every 5 minutes, snapshots the shared `results` list (under
    `results_lock`) to CHECKPOINT_FILE so a crash loses at most one
    interval of work.

    Fix: uses stop_event.wait() instead of time.sleep() so the thread
    exits as soon as the event is set, rather than sleeping out the
    remainder of a 5-minute interval.
    """
    # wait() returns True the moment the event is set, False on timeout.
    while not stop_event.wait(300):  # every 5 minutes
        with results_lock:
            if results:
                with open(CHECKPOINT_FILE, 'w') as f:
                    json.dump({"scored_functions": results}, f, indent=2)
                print(f"[CHECKPOINT] Saved {len(results)} results")
def progress_tracker(results: List, results_lock: threading.Lock, stats: Dict,
                     stats_lock: threading.Lock, total_functions: int,
                     start_time: float, stop_event: threading.Event, initial_count: int):
    """Print a progress/ETA summary once a minute until stop_event is set.

    `initial_count` is how many results were already present when the
    session started (checkpoint resume), so session throughput reflects
    only work done this run. ETA prefers the recent (last-interval)
    rate, falling back to the session average.
    """
    prev_count = 0
    prev_tick = start_time
    baseline = initial_count
    while not stop_event.is_set():
        time.sleep(60)  # update every 60 seconds
        with results_lock:
            done = len(results)
        if done == 0:
            continue  # nothing to report yet
        now = time.time()
        elapsed = now - start_time
        interval = now - prev_tick
        session_done = done - baseline
        remaining = total_functions - done
        avg_rate = session_done / elapsed if elapsed > 0 else 0
        recent_rate = (done - prev_count) / interval if interval > 0 else 0
        eta_rate = recent_rate if recent_rate > 0 else avg_rate
        eta_hours = (remaining / eta_rate if eta_rate > 0 else 0) / 3600
        with stats_lock:
            per_worker = [(key.replace('_scored', ''), val)
                          for key, val in stats.items() if key.endswith('_scored')]
        per_worker.sort(key=lambda item: item[1], reverse=True)
        bar = '=' * 70
        print(f"\n{bar}")
        print(f"PROGRESS: {done:,}/{total_functions:,} ({done/total_functions*100:.1f}%)")
        print(f"{bar}")
        print(f"Session: {session_done:,} functions in {elapsed/3600:.1f}h | {avg_rate:.2f}/s avg")
        print(f"ETA: {eta_hours:.1f}h | Throughput: {recent_rate:.2f}/s recent")
        print(f"\nPer-worker stats:")
        for worker_name, count in per_worker:
            share = (count / done * 100) if done > 0 else 0
            per_rate = count / elapsed if elapsed > 0 else 0
            print(f"  {worker_name:20s} {count:6,} ({share:5.1f}%) | {per_rate:.2f}/s")
        print(f"{bar}\n")
        prev_count = done
        prev_tick = time.time()
def main():
    """Orchestrate the full pipeline: check endpoints, parse and filter
    functions, resume from checkpoint, then run the producer/consumer
    scoring threads and write final results.

    Fix: the final per-endpoint summary previously looked up
    stats[f'{name}_scored'] with the base endpoint name, but workers on
    multi-worker endpoints record under suffixed names ("<name>-W1"),
    so the summary always printed 0 for them. It now sums all matching
    worker keys.
    """
    if not os.path.exists(INPUT_FILE):
        print(f"Error: {INPUT_FILE} not found")
        sys.exit(1)
    print("=" * 70)
    print("Distributed Zig Function Scoring")
    print("=" * 70)
    # Check which Ollama endpoints are available
    print("Checking Ollama endpoints...")
    available_endpoints = []
    for endpoint in OLLAMA_ENDPOINTS:
        # check_ollama_endpoint also stores the chosen model on the dict.
        if check_ollama_endpoint(endpoint):
            available_endpoints.append(endpoint)
    if not available_endpoints:
        print("\n❌ No Ollama endpoints available!")
        sys.exit(1)
    print(f"\n{len(available_endpoints)} Ollama endpoint(s) available")
    print("=" * 70)
    print()
    # Parse and filter
    print("Parsing functions...")
    functions = parse_zig_functions(INPUT_FILE)
    print(f"Found {len(functions)} functions")
    print()
    functions = regex_prefilter(functions)
    if len(functions) == 0:
        print("No functions remaining after filter!")
        sys.exit(1)
    # Check for checkpoint and resume
    existing_results = []
    scored_function_keys = set()
    # Denominator for progress percentages: pre-resume total.
    original_total_functions = len(functions)
    if os.path.exists(CHECKPOINT_FILE):
        print("⚠️ Checkpoint found! Loading previous results...")
        try:
            with open(CHECKPOINT_FILE, 'r') as f:
                checkpoint_data = json.load(f)
            existing_results = checkpoint_data.get("scored_functions", [])
            # Identity key is (repository, file, function_name).
            scored_function_keys = set(
                (f.get("repository", ""), f.get("file", ""), f.get("function_name", ""))
                for f in existing_results
            )
            print(f"✓ Loaded {len(existing_results)} already-scored functions")
            functions = [
                f for f in functions
                if (f.get("repository", ""), f.get("file", ""), f.get("function_name", ""))
                not in scored_function_keys
            ]
            print(f"✓ {len(functions)} functions remaining to score")
            if len(functions) == 0:
                print("\n✓ All functions already scored!")
                return
        except Exception as e:
            # Corrupt checkpoint: warn and start from scratch.
            print(f"⚠️ Error loading checkpoint: {e}")
            existing_results = []
    # Calculate total workers
    total_workers = sum(endpoint.get("workers", WORKERS_PER_ENDPOINT)
                        for endpoint in available_endpoints)
    worker_breakdown = ", ".join(
        [f"{ep['name']}:{ep.get('workers', WORKERS_PER_ENDPOINT)}"
         for ep in available_endpoints])
    print(f"Starting distributed scoring on {len(functions)} functions...")
    print(f"Workers: {total_workers} total ({worker_breakdown})")
    print()
    # Setup: shared state guarded by the two locks below.
    work_queue = queue.Queue(maxsize=WORK_QUEUE_SIZE)
    results = existing_results
    results_lock = threading.Lock()
    stats = {}
    stats_lock = threading.Lock()
    stop_checkpoint = threading.Event()
    stop_progress = threading.Event()
    # Start producer
    producer = threading.Thread(
        target=producer_worker,
        args=(functions, work_queue, stats, total_workers),
        daemon=True
    )
    # Start Ollama workers
    ollama_workers = []
    for endpoint in available_endpoints:
        num_workers = endpoint.get("workers", WORKERS_PER_ENDPOINT)
        for worker_id in range(num_workers):
            endpoint_copy = endpoint.copy()
            if num_workers > 1:
                # Distinct display/stats name per thread on the same endpoint.
                endpoint_copy["name"] = f"{endpoint['name']}-W{worker_id+1}"
            worker = threading.Thread(
                target=ollama_scorer_worker,
                args=(endpoint_copy, work_queue, results,
                      results_lock, stats, stats_lock),
                daemon=True
            )
            ollama_workers.append(worker)
    checkpoint_thread = threading.Thread(
        target=checkpoint_worker,
        args=(results, results_lock, stop_checkpoint),
        daemon=True
    )
    progress_thread = threading.Thread(
        target=progress_tracker,
        args=(results, results_lock, stats, stats_lock, original_total_functions,
              time.time(), stop_progress, len(existing_results)),
        daemon=True
    )
    start_time = time.time()
    producer.start()
    for worker in ollama_workers:
        worker.start()
    checkpoint_thread.start()
    progress_thread.start()
    # Wait for completion: producer finishes queuing, workers drain + exit.
    producer.join()
    for worker in ollama_workers:
        worker.join()
    stop_checkpoint.set()
    stop_progress.set()
    elapsed = time.time() - start_time
    # Save final results
    with open(OUTPUT_FILE, 'w') as f:
        json.dump(results, f, indent=2)
    # Clean up checkpoint
    if os.path.exists(CHECKPOINT_FILE):
        os.remove(CHECKPOINT_FILE)
    print("\n" + "=" * 70)
    print("COMPLETE")
    print("=" * 70)
    print(f"Total scored: {len(results)}")
    for endpoint in available_endpoints:
        name = endpoint['name']
        # Sum across the per-thread stats keys ("<name>_scored" for a
        # single worker, "<name>-W<k>_scored" for multi-worker endpoints).
        count = sum(v for k, v in stats.items()
                    if k == f'{name}_scored' or k.startswith(f'{name}-W'))
        percentage = (count / len(results) * 100) if len(results) > 0 else 0
        print(f"{name} scored: {count} ({percentage:.1f}%)")
    print(f"Time: {elapsed / 3600:.1f} hours")
    if elapsed > 0:
        rate = len(results) / elapsed
        print(f"Average rate: {rate:.2f} functions/second")
    print(f"Results saved to: {OUTPUT_FILE}")


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment