#!/usr/bin/env python3
"""
BirdAI Cascaded Extraction — Consistency Test
"""
import json
import os
import urllib.request
import urllib.error
import psycopg2
import psycopg2.extras
import hashlib
import time
from datetime import datetime
from dotenv import load_dotenv

load_dotenv(os.path.expanduser("~/aaronai/.env"))

PG_DSN = os.getenv("PG_DSN")
RESULTS_FILE = os.path.expanduser("~/aaronai/consistency_test_results.json")
MODEL = "mistral"
PASSES = 3
SAMPLE_SIZE = 50
OLLAMA_URL = "http://localhost:11434/api/generate"

EXTRACTION_PROMPT = """Extract named entities from this text.
Return JSON only, no explanation, no prose.
Use exactly these fields (omit any field you are uncertain about, use empty list if none found):

{
  "people": [],
  "organizations": [],
  "locations": [],
  "dates": [],
  "document_type": ""
}

Rules:
- Every value in people, organizations, locations, dates must be a plain string
- document_type must be a plain string
- No nested objects, no nested lists
- Only include entities you are certain about
- If uncertain about anything, omit it

Text:
"""


def get_sample_documents():
    # Pull a random sample of mid-sized documents from the embeddings table.
    conn = psycopg2.connect(PG_DSN)
    cur = conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
    cur.execute("""
        SELECT id, document, source, created_at
        FROM embeddings
        WHERE length(document) > 100 AND length(document) < 3000
        ORDER BY random()
        LIMIT %s
    """, (SAMPLE_SIZE,))
    docs = cur.fetchall()
    cur.close()
    conn.close()
    return docs


def run_extraction(text):
    # Send one extraction prompt to the local Ollama server and parse the
    # first JSON object found in its response. Returns (parsed, raw) on
    # success, or (None, error_string) on failure.
    prompt = EXTRACTION_PROMPT + text[:1500]
    payload = json.dumps({
        "model": MODEL,
        "prompt": prompt,
        "stream": False
    }).encode()
    try:
        req = urllib.request.Request(
            OLLAMA_URL,
            data=payload,
            headers={"Content-Type": "application/json"}
        )
        with urllib.request.urlopen(req, timeout=180) as resp:
            result = json.loads(resp.read().decode())
        raw = result.get("response", "").strip()
        start = raw.find("{")
        end = raw.rfind("}") + 1
        if start == -1 or end == 0:
            return None, f"NO_JSON: {raw[:100]}"
        json_str = raw[start:end]
        parsed = json.loads(json_str)
        if not isinstance(parsed, dict):
            return None, f"NOT_DICT: {json_str[:100]}"
        return parsed, raw
    except urllib.error.URLError as e:
        # Note: timeouts during connection setup surface as URLError,
        # not TimeoutError, and are counted as URL errors here.
        return None, f"URL_ERROR: {e}"
    except TimeoutError:
        return None, "TIMEOUT"
    except json.JSONDecodeError as e:
        return None, f"JSON_ERROR: {e}"
    except Exception as e:
        return None, f"ERROR: {type(e).__name__}: {e}"


def flatten_value(v):
    # Reduce any extracted value to a canonical lowercase string so that
    # two passes can be compared field by field.
    if isinstance(v, str):
        return v.lower().strip()
    elif isinstance(v, dict):
        return json.dumps(v, sort_keys=True).lower()
    elif isinstance(v, list):
        return json.dumps(sorted([flatten_value(i) for i in v]))
    else:
        return str(v).lower().strip()


def normalize_extraction(extracted):
    # Restrict the extraction to the expected fields and normalize each value
    # (sorted, lowercased) so order and casing differences don't count as drift.
    if extracted is None:
        return None
    normalized = {}
    expected_fields = ["people", "organizations", "locations", "dates", "document_type"]
    for key in expected_fields:
        val = extracted.get(key, [] if key != "document_type" else "")
        if isinstance(val, list):
            normalized[key] = sorted([flatten_value(v) for v in val])
        else:
            normalized[key] = flatten_value(val)
    return normalized


def extractions_consistent(extractions):
    # True only if every pass parsed and all normalized extractions are identical.
    if any(e is None for e in extractions):
        return False
    normalized = [normalize_extraction(e) for e in extractions]
    if any(n is None for n in normalized):
        return False
    return all(n == normalized[0] for n in normalized[1:])


def content_hash(text):
    # Short MD5 prefix used only as a display identifier in the log output.
    return hashlib.md5(text.encode()).hexdigest()[:8]


def main():
    print(f"\nBirdAI Consistency Test")
    print(f"Model: {MODEL} | Passes: {PASSES} | Sample: {SAMPLE_SIZE} docs")
    print(f"Started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"Results: {RESULTS_FILE}")
    print("-" * 60)

    docs = get_sample_documents()
    print(f"Loaded {len(docs)} documents from pgvector\n")

    results = {
        "meta": {
            "model": MODEL,
            "passes": PASSES,
            "sample_size": len(docs),
            "started": datetime.now().isoformat(),
            "completed": None
        },
        "documents": [],
        "summary": {}
    }

    consistent_count = 0
    failed_count = 0
    timeout_count = 0

    for i, doc in enumerate(docs):
        doc_id = doc["id"]
        content = doc["document"]
        source = doc.get("source", "unknown")
        chash = content_hash(content)

        print(f"[{i+1:02d}/{len(docs)}] {source[:50]:<50} hash:{chash}", end=" ", flush=True)

        # Run the same extraction PASSES times and record timing and raw output.
        passes = []
        pass_times = []
        raw_outputs = []
        for p in range(PASSES):
            t_start = time.time()
            extracted, raw = run_extraction(content)
            t_end = time.time()
            passes.append(extracted)
            pass_times.append(round(t_end - t_start, 1))
            raw_outputs.append(raw[:200] if raw else "")

        consistent = extractions_consistent(passes)
        any_timeout = any("TIMEOUT" in str(r) for r in raw_outputs)
        any_failed = any(p is None for p in passes)

        if any_timeout:
            timeout_count += 1
            status = "TIMEOUT"
        elif any_failed:
            failed_count += 1
            status = "FAILED"
        elif consistent:
            consistent_count += 1
            status = "CONSISTENT"
        else:
            status = "INCONSISTENT"

        print(f"→ {status} ({'/'.join(str(t) for t in pass_times)}s)")

        try:
            sample_extraction = normalize_extraction(passes[0]) if passes[0] else None
        except Exception:
            sample_extraction = None

        results["documents"].append({
            "id": doc_id,
            "source": source,
            "content_hash": chash,
            "content_length": len(content),
            "status": status,
            "consistent": consistent,
            "pass_times_seconds": pass_times,
            "extraction_sample": sample_extraction,
            "raw_samples": raw_outputs
        })

        # Write partial results after every document so a crash loses nothing.
        with open(RESULTS_FILE, "w") as f:
            json.dump(results, f, indent=2, default=str)

    total = len(docs)
    completed_at = datetime.now().isoformat()
    results["meta"]["completed"] = completed_at

    summary = {
        "total": total,
        "consistent": consistent_count,
        "inconsistent": total - consistent_count - failed_count - timeout_count,
        "failed": failed_count,
        "timeout": timeout_count,
        "consistency_rate": round(consistent_count / total * 100, 1),
        "cascade_viable": consistent_count / total >= 0.5
    }
    results["summary"] = summary

    with open(RESULTS_FILE, "w") as f:
        json.dump(results, f, indent=2, default=str)

    print("\n" + "=" * 60)
    print(f"RESULTS")
    print(f"  Consistent:      {consistent_count}/{total} ({summary['consistency_rate']}%)")
    print(f"  Inconsistent:    {summary['inconsistent']}")
    print(f"  Failed/Timeout:  {failed_count + timeout_count}")
    print(f"  Cascade viable:  {'YES' if summary['cascade_viable'] else 'NO — reconsider architecture'}")
    print(f"  Completed:       {completed_at}")
    print(f"  Full results:    {RESULTS_FILE}")
    print("=" * 60)


if __name__ == "__main__":
    main()