Skip to content

Instantly share code, notes, and snippets.

@rsarv3006
Created January 16, 2026 21:57
Show Gist options
  • Select an option

  • Save rsarv3006/e5cbba4e00799fd399adb940fb2ff2e1 to your computer and use it in GitHub Desktop.

Select an option

Save rsarv3006/e5cbba4e00799fd399adb940fb2ff2e1 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python3
"""
Distributed Zig function scoring with Ollama endpoints.
Scores Zig functions for quality, complexity, and production-readiness.
"""
import requests
from typing import Dict, List, Any
import time
import queue
import threading
import sys
import json
import os
# Configuration
# Input produced by the extraction step: one "===FUNCTION===" record per function.
INPUT_FILE = "extracted_functions.txt"
# Final output: JSON array of scored function records.
OUTPUT_FILE = "scored_zig_functions.json"
# Periodic progress snapshot; loaded on startup to resume an interrupted run.
CHECKPOINT_FILE = ".scored_functions_checkpoint.json"
# Ollama endpoints - add as many as you want!
# Each entry: display name, /api/generate URL, and the number of concurrent
# scorer threads to run against that server.
OLLAMA_ENDPOINTS = [
# {"name": "P40-Local", "url": "http://localhost:11434/api/generate", "workers": 2},
{"name": "MacBook-22", "url": "http://10.0.0.22:11434/api/generate", "workers": 2},
{"name": "MacBook-31", "url": "http://10.0.0.31:11434/api/generate", "workers": 2},
]
# Fallback thread count for endpoints that omit the "workers" key.
WORKERS_PER_ENDPOINT = 2
# Ollama model to use
# Tried in order; the first model installed on an endpoint is selected for it.
OLLAMA_MODEL_PREFERENCES = [
"qwen2.5-coder:7b-instruct",
]
# NOTE(review): appears unused in this file — checkpointing here is time-based
# (every 5 minutes in checkpoint_worker); confirm before removing.
CHECKPOINT_INTERVAL = 200
# Upper bound on queued-but-unscored items; bounds memory and backpressures
# the producer via queue.Queue(maxsize=...).
WORK_QUEUE_SIZE = 500
# Prompt template; {code} is substituted per function and the doubled braces
# render a literal JSON skeleton for the model to imitate.
SCORING_PROMPT = """Rate this Zig function (0-100 points).
COMPLETENESS (0-25): Full implementation?
COMPLEXITY (0-20): Non-trivial logic?
REAL_WORLD (0-20): Production-worthy? (Error handling with !, allocator patterns, etc.)
CODE_QUALITY (0-20): Idiomatic Zig? (Proper error unions, memory management, etc.)
EDUCATIONAL (0-15): Teaches Zig concepts? (Comptime, optionals, slices, etc.)
Function:
```zig
{code}
```
JSON output only:
{{
"completeness": <0-25>,
"complexity": <0-20>,
"realWorld": <0-20>,
"quality": <0-20>,
"educational": <0-15>,
"overallScore": <sum>
}}
"""
def parse_zig_functions(file_path: str) -> List[Dict[str, Any]]:
    """Parse extracted_functions.txt.

    Records are delimited by "===FUNCTION===" lines; each carries
    Name:/Repo:/File: metadata followed by a "Code:" line and the raw
    Zig source until the next delimiter. Records missing either
    metadata or code are silently dropped.
    """
    functions: List[Dict[str, Any]] = []
    meta: Dict[str, Any] = {}
    code_lines: List[str] = []
    collecting = False

    def flush():
        # Emit the accumulated record only when it has both metadata and code.
        if meta and code_lines:
            record = dict(meta)
            record['code'] = '\n'.join(code_lines)
            functions.append(record)

    with open(file_path, 'r', encoding='utf-8', errors='ignore') as handle:
        for raw in handle:
            text = raw.rstrip('\n')
            if text.startswith('===FUNCTION==='):
                flush()
                meta = {}
                code_lines = []
                collecting = False
            elif text.startswith('Name:'):
                meta['function_name'] = text.split(':', 1)[1].strip()
            elif text.startswith('Repo:'):
                meta['repository'] = text.split(':', 1)[1].strip()
            elif text.startswith('File:'):
                meta['file'] = text.split(':', 1)[1].strip()
            elif text.startswith('Code:'):
                collecting = True
            elif collecting:
                code_lines.append(text)
    # The final record has no trailing delimiter; flush it explicitly.
    flush()
    return functions
def regex_prefilter(functions: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Fast regex-based filtering for Zig functions.

    Drops test/benchmark functions and bodiless functions; everything
    else (including very short functions) is passed through for the
    model to score.
    """
    print("Running regex pre-filter (no GPU)...")
    kept: List[Dict[str, Any]] = []
    test_count = 0
    empty_count = 0
    for entry in functions:
        source = entry.get('code', '')
        lowered_name = entry.get('function_name', '').lower()
        # Tests and benchmarks rarely make good scoring candidates.
        is_test = (lowered_name.startswith('test')
                   or lowered_name.startswith('benchmark')
                   or 'test.zig' in entry.get('file', ''))
        if is_test:
            test_count += 1
            continue
        # A function counts as empty when, after dropping blanks and
        # // comments, nothing remains but the signature and braces.
        stripped = [ln.strip() for ln in source.split('\n')]
        meaningful = [ln for ln in stripped if ln and not ln.startswith('//')]
        body = [ln for ln in meaningful
                if not ln.startswith('fn')
                and not ln.startswith('pub fn')
                and ln not in ('{', '}')]
        if not body:
            empty_count += 1
            continue
        # Deliberately no "too trivial" filter: concise functions
        # (e.g. LeetCode-style solutions) can still be elegant.
        kept.append(entry)
    print(f"✓ Regex pre-filter complete:")
    print(f" Input: {len(functions)}")
    print(f" Rejected (test): {test_count}")
    print(f" Rejected (empty): {empty_count}")
    print(f" Remaining: {len(kept)} (let model score short functions)")
    print()
    return kept
def check_ollama_endpoint(endpoint: Dict[str, str]) -> bool:
    """Check if an Ollama endpoint is reachable and has a preferred model.

    Queries the server's /api/tags model list. On success, stores the
    chosen model name in endpoint["model"] and returns True; returns
    False on connection failure, non-200 response, or when none of the
    OLLAMA_MODEL_PREFERENCES are installed.
    """
    try:
        # /api/tags lists the models installed on the server.
        tags_url = endpoint["url"].replace("/api/generate", "/api/tags")
        # Fixed: the original f-string was split with a newline inside the
        # {...} braces, which is a syntax error before Python 3.12 (PEP 701).
        print(f" [{endpoint['name']}] Checking {tags_url}...", end=" ", flush=True)
        response = requests.get(tags_url, timeout=5)
        if response.status_code != 200:
            print(f"\n⚠️ HTTP {response.status_code}")
            return False
        data = response.json()
        models = [m["name"] for m in data.get("models", [])]
        # Pick the first preferred model that is actually installed here.
        for model in OLLAMA_MODEL_PREFERENCES:
            if model in models:
                endpoint["model"] = model
                print(f"✓ (using {model})")
                return True
        print(f"\n⚠️ None of the preferred models found")
        print(f" Available: {', '.join(models[:5])}")
        return False
    except Exception as e:
        # Any network/JSON error means "endpoint unavailable" — skip it.
        print(f"\n⚠️ Connection failed: {str(e)[:100]}")
        return False
def call_ollama(code: str, endpoint_url: str, model: str) -> Dict[str, Any]:
    """Score a Zig function via an Ollama /api/generate endpoint.

    Args:
        code: Zig source of the function (truncated to 1000 chars).
        endpoint_url: Full /api/generate URL of the Ollama server.
        model: Model name to request.

    Returns:
        Parsed score dict (completeness / complexity / realWorld /
        quality / educational / overallScore).

    Raises:
        requests.HTTPError: on non-2xx responses.
        ValueError: if no JSON object is found in the model output.
        json.JSONDecodeError: if the extracted span is not valid JSON.
    """
    # Keep prompts short: long bodies add latency without improving scores.
    if len(code) > 1000:
        code = code[:1000] + "\n... (truncated)"
    prompt = SCORING_PROMPT.format(code=code)
    response = requests.post(
        endpoint_url,
        json={
            "model": model,
            "prompt": prompt,
            "stream": False,
            "system": "You are a Zig code quality evaluator. Score each criterion independently and respond with valid JSON only.",
            "options": {
                "temperature": 0.7,
                # The expected JSON reply is tiny; cap generation length.
                "num_predict": 100,
            }
        },
        timeout=300
    )
    response.raise_for_status()
    result = response.json()
    return _extract_score_json(result.get("response", ""))


def _extract_score_json(text: str) -> Dict[str, Any]:
    """Extract and parse the first JSON object in a model reply.

    Handles replies wrapped in markdown ``` fences and replies with
    prose before/after the JSON object.
    """
    text = text.strip()
    if text.startswith("```"):
        # Drop the opening and closing fence lines if both are present.
        lines = text.split("\n")
        text = "\n".join(lines[1:-1]) if len(lines) > 2 else text
    start_idx = text.find("{")
    end_idx = text.rfind("}") + 1
    if start_idx != -1 and end_idx > start_idx:
        return json.loads(text[start_idx:end_idx])
    # Fixed: was an f-string with no placeholders.
    raise ValueError("No JSON found in Ollama response")
def producer_worker(functions: List[Dict[str, Any]], work_queue: queue.Queue, stats: Dict, num_workers: int):
    """Feed every function onto the shared work queue.

    After all work is enqueued, puts one None sentinel per consumer so
    each scorer thread knows when to stop, then records the queued
    total in stats['queued'].
    """
    print("[PRODUCER] Starting...")
    total = len(functions)
    for idx, item in enumerate(functions, start=1):
        work_queue.put(item)
        if idx % 1000 == 0:
            print(f"[PRODUCER] Queued {idx}/{total}")
    # One sentinel per worker: each consumer exits on its first None.
    for _ in range(num_workers):
        work_queue.put(None)
    stats['queued'] = total
    print(f"[PRODUCER] Complete! Queued {total} functions")
def ollama_scorer_worker(endpoint: Dict[str, str], work_queue: queue.Queue, results: List,
results_lock: threading.Lock, stats: Dict, stats_lock: threading.Lock):
"""Ollama scorer: Pull from queue and score with Ollama endpoint."""
name = endpoint["name"]
url = endpoint["url"]
model = endpoint.get("model", OLLAMA_MODEL_PREFERENCES[0])
print(f"[{name}] Scorer starting...")
scored_count = 0
worker_start = time.time()
while True:
func = work_queue.get()
if func is None:
work_queue.task_done()
break
retry_count = func.get('_retry_count', 0)
try:
score_data = call_ollama(func['code'], url, model)
scored_func = func.copy()
scored_func.pop('_retry_count', None)
scored_func["overallScore"] = int(score_data["overallScore"])
scored_func["scoreBreakdown"] = {
"completeness": score_data.get("completeness", 0),
"complexity": score_data.get("complexity", 0),
"realWorld": score_data.get("realWorld", 0),
"quality": score_data.get("quality", 0),
"educational": score_data.get("educational", 0)
}
scored_func["scorer"] = name
with results_lock:
results.append(scored_func)
scored_count += 1
with stats_lock:
stats[f'{name}_scored'] = scored_count
if scored_count % 100 == 0:
worker_elapsed = time.time() - worker_start
rate = scored_count / worker_elapsed
total_scored = len(results)
print(f"[{name}] {scored_count} scored | {
rate:.2f}/s | Total: {total_scored:,}")
except Exception as e:
func['_retry_count'] = retry_count + 1
work_queue.put(func)
if retry_count > 0 and retry_count % 5 == 0:
print(f"[{name}] Error scoring {func['function_name']
}, retry {retry_count + 1}: {str(e)[:100]}")
work_queue.task_done()
with stats_lock:
stats[f'{name}_scored'] = scored_count
print(f"[{name}] Complete! Scored {scored_count} functions")
def checkpoint_worker(results: List, results_lock: threading.Lock, stop_event: threading.Event):
    """Periodically save scored results to CHECKPOINT_FILE until stopped.

    Uses Event.wait() instead of time.sleep() so the thread exits as
    soon as stop_event is set, rather than blocking for up to a full
    5-minute interval after shutdown begins.
    """
    # wait() returns False on timeout (time to checkpoint) and True
    # once stop_event is set (time to exit).
    while not stop_event.wait(300):  # Every 5 minutes
        with results_lock:
            if results:
                with open(CHECKPOINT_FILE, 'w') as f:
                    json.dump({"scored_functions": results}, f, indent=2)
                print(f"[CHECKPOINT] Saved {len(results)} results")
def progress_tracker(results: List, results_lock: threading.Lock, stats: Dict,
                     stats_lock: threading.Lock, total_functions: int,
                     start_time: float, stop_event: threading.Event, initial_count: int):
    """Print overall and per-worker progress every 60 seconds until stopped.

    `initial_count` is the number of results loaded from a checkpoint,
    so session throughput is measured over this run only. Uses
    Event.wait() so shutdown is prompt rather than waiting out a full
    sleep interval; the original also contained f-strings broken across
    lines inside {...}, a syntax error before Python 3.12 (PEP 701).
    """
    last_count = 0
    last_update_time = start_time
    session_start_count = initial_count
    while not stop_event.wait(60):  # Update every 60 seconds
        with results_lock:
            current_count = len(results)
        if current_count == 0:
            continue
        elapsed = time.time() - start_time
        interval = time.time() - last_update_time
        completed = current_count
        session_completed = completed - session_start_count
        remaining = total_functions - completed
        session_rate = session_completed / elapsed if elapsed > 0 else 0
        recent_rate = (current_count - last_count) / interval if interval > 0 else 0
        # Prefer the recent rate for ETA; fall back to the session average.
        rate_for_eta = recent_rate if recent_rate > 0 else session_rate
        eta_seconds = remaining / rate_for_eta if rate_for_eta > 0 else 0
        eta_hours = eta_seconds / 3600
        with stats_lock:
            worker_stats = [(k.replace('_scored', ''), v)
                            for k, v in stats.items() if k.endswith('_scored')]
        worker_stats.sort(key=lambda x: x[1], reverse=True)
        print(f"\n{'='*70}")
        print(f"PROGRESS: {completed:,}/{total_functions:,} "
              f"({completed/total_functions*100:.1f}%)")
        print(f"{'='*70}")
        print(f"Session: {session_completed:,} functions in "
              f"{elapsed/3600:.1f}h | {session_rate:.2f}/s avg")
        print(f"ETA: {eta_hours:.1f}h | Throughput: {recent_rate:.2f}/s recent")
        print(f"\nPer-worker stats:")
        for name, count in worker_stats:
            pct = (count / completed * 100) if completed > 0 else 0
            worker_rate = count / elapsed if elapsed > 0 else 0
            print(f" {name:20s} {count:6,} ({pct:5.1f}%) | {worker_rate:.2f}/s")
        print(f"{'='*70}\n")
        last_count = current_count
        last_update_time = time.time()
def main():
    """Entry point: parse, filter, and score Zig functions across endpoints.

    Pipeline: check endpoints -> parse input -> regex pre-filter ->
    resume from checkpoint if present -> run producer/scorer/checkpoint/
    progress threads -> write OUTPUT_FILE and delete the checkpoint.
    """
    if not os.path.exists(INPUT_FILE):
        print(f"Error: {INPUT_FILE} not found")
        sys.exit(1)
    print("=" * 70)
    print("Distributed Zig Function Scoring")
    print("=" * 70)
    # Check which Ollama endpoints are available
    print("Checking Ollama endpoints...")
    available_endpoints = []
    for endpoint in OLLAMA_ENDPOINTS:
        if check_ollama_endpoint(endpoint):
            available_endpoints.append(endpoint)
    if not available_endpoints:
        print("\n❌ No Ollama endpoints available!")
        sys.exit(1)
    print(f"\n{len(available_endpoints)} Ollama endpoint(s) available")
    print("=" * 70)
    print()
    # Parse and filter
    print("Parsing functions...")
    functions = parse_zig_functions(INPUT_FILE)
    print(f"Found {len(functions)} functions")
    print()
    functions = regex_prefilter(functions)
    if len(functions) == 0:
        print("No functions remaining after filter!")
        sys.exit(1)
    # Resume support: skip any (repo, file, name) triple already scored
    # in a previous interrupted run.
    existing_results = []
    scored_function_keys = set()
    original_total_functions = len(functions)
    if os.path.exists(CHECKPOINT_FILE):
        print("⚠️ Checkpoint found! Loading previous results...")
        try:
            with open(CHECKPOINT_FILE, 'r') as f:
                checkpoint_data = json.load(f)
            existing_results = checkpoint_data.get("scored_functions", [])
            scored_function_keys = set(
                (r.get("repository", ""), r.get("file", ""), r.get("function_name", ""))
                for r in existing_results
            )
            print(f"✓ Loaded {len(existing_results)} already-scored functions")
            functions = [
                f for f in functions
                if (f.get("repository", ""), f.get("file", ""), f.get("function_name", ""))
                not in scored_function_keys
            ]
            print(f"✓ {len(functions)} functions remaining to score")
            if len(functions) == 0:
                print("\n✓ All functions already scored!")
                return
        except Exception as e:
            # A corrupt checkpoint shouldn't abort the run; start fresh.
            print(f"⚠️ Error loading checkpoint: {e}")
            existing_results = []
    # Calculate total workers
    total_workers = sum(endpoint.get("workers", WORKERS_PER_ENDPOINT)
                        for endpoint in available_endpoints)
    # Fixed: this f-string was split with a newline inside {...}, a syntax
    # error before Python 3.12 (PEP 701).
    worker_breakdown = ", ".join(
        f"{ep['name']}:{ep.get('workers', WORKERS_PER_ENDPOINT)}"
        for ep in available_endpoints
    )
    print(f"Starting distributed scoring on {len(functions)} functions...")
    print(f"Workers: {total_workers} total ({worker_breakdown})")
    print()
    # Shared state for all threads.
    work_queue = queue.Queue(maxsize=WORK_QUEUE_SIZE)
    results = existing_results
    results_lock = threading.Lock()
    stats = {}
    stats_lock = threading.Lock()
    stop_checkpoint = threading.Event()
    stop_progress = threading.Event()
    # Start producer
    producer = threading.Thread(
        target=producer_worker,
        args=(functions, work_queue, stats, total_workers),
        daemon=True
    )
    # One scorer thread per configured worker slot per endpoint.
    ollama_workers = []
    for endpoint in available_endpoints:
        num_workers = endpoint.get("workers", WORKERS_PER_ENDPOINT)
        for worker_id in range(num_workers):
            endpoint_copy = endpoint.copy()
            if num_workers > 1:
                # Give each thread a distinct display/stats name.
                endpoint_copy["name"] = f"{endpoint['name']}-W{worker_id+1}"
            worker = threading.Thread(
                target=ollama_scorer_worker,
                args=(endpoint_copy, work_queue, results,
                      results_lock, stats, stats_lock),
                daemon=True
            )
            ollama_workers.append(worker)
    checkpoint_thread = threading.Thread(
        target=checkpoint_worker,
        args=(results, results_lock, stop_checkpoint),
        daemon=True
    )
    progress_thread = threading.Thread(
        target=progress_tracker,
        args=(results, results_lock, stats, stats_lock, original_total_functions,
              time.time(), stop_progress, len(existing_results)),
        daemon=True
    )
    start_time = time.time()
    producer.start()
    for worker in ollama_workers:
        worker.start()
    checkpoint_thread.start()
    progress_thread.start()
    # Wait for the producer to enqueue everything, then for all scorers.
    producer.join()
    for worker in ollama_workers:
        worker.join()
    stop_checkpoint.set()
    stop_progress.set()
    elapsed = time.time() - start_time
    # Save final results
    with open(OUTPUT_FILE, 'w') as f:
        json.dump(results, f, indent=2)
    # The checkpoint is superseded by the final output file.
    if os.path.exists(CHECKPOINT_FILE):
        os.remove(CHECKPOINT_FILE)
    print("\n" + "=" * 70)
    print("COMPLETE")
    print("=" * 70)
    print(f"Total scored: {len(results)}")
    for endpoint in available_endpoints:
        name = endpoint['name']
        # Fixed: worker threads report under "<name>-W<i>_scored" when an
        # endpoint has multiple workers, so looking up the bare endpoint
        # name always returned 0. Sum every stats key for this endpoint.
        count = sum(v for k, v in stats.items()
                    if k == f'{name}_scored' or k.startswith(f'{name}-W'))
        percentage = (count / len(results) * 100) if len(results) > 0 else 0
        print(f"{name} scored: {count} ({percentage:.1f}%)")
    print(f"Time: {elapsed / 3600:.1f} hours")
    if elapsed > 0:
        rate = len(results) / elapsed
        print(f"Average rate: {rate:.2f} functions/second")
    print(f"Results saved to: {OUTPUT_FILE}")


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment