#!/usr/bin/env python3
"""
Experiment 003 — Entity-Only Consistency Test

Runs N_PASSES Mistral extraction passes per document and measures how often
the passes agree on the entity fields only (people, organizations, locations,
dates). The document_type label is excluded.

Sampling: one document per distinct source (all chunks of a source are
concatenated via GROUP BY source) — fixes the Exp 001 chunk-replacement flaw.

Outputs: ~/aaronai/experiments/consistency_test_v2_results.json
"""

import json
import os
import re
import sys
import time
from datetime import datetime, timezone
from pathlib import Path

import psycopg2
import requests
from dotenv import load_dotenv

load_dotenv(Path.home() / "aaronai" / ".env")

OUTPUT_FILE = Path.home() / "aaronai" / "experiments" / "consistency_test_v2_results.json"
OLLAMA_URL = "http://localhost:11434/api/generate"
MODEL = "mistral"
N_PASSES = 3
N_DOCS = 50
PER_CALL_TIMEOUT = 60  # seconds — fail fast, don't wedge
MAX_DOC_CHARS = 8000  # cap document length sent to Mistral
EXP_001_BASELINE_PCT = 18.0  # consistency rate reported by Experiment 001

# Entity fields compared between passes; shared by parsing, matching and scoring.
ENTITY_FIELDS = ("people", "organizations", "locations", "dates")

EXTRACTION_PROMPT = """Extract entities from the document below. Return ONLY valid JSON with this exact schema:

{
  "people": [string],
  "organizations": [string],
  "locations": [string],
  "dates": [string]
}

Rules:
- Only include entities you are CERTAIN about. If uncertain, omit.
- No prose, no markdown fences, no commentary. JSON only.
- Empty arrays are valid.

DOCUMENT:
"""


def call_mistral(document_text):
    """Send one extraction request to the Ollama endpoint.

    The document is truncated to MAX_DOC_CHARS before being appended to
    EXTRACTION_PROMPT.

    Returns a dict. On success it has "response" (the model's raw text),
    "latency_s" and "truncated" (whether the input was cut). On failure it
    has "error" and "latency_s" instead; this function never raises, so one
    bad call cannot abort the whole experiment.
    """
    prompt_body = document_text[:MAX_DOC_CHARS]
    # perf_counter is monotonic, so latencies are immune to wall-clock jumps.
    started = time.perf_counter()
    try:
        resp = requests.post(
            OLLAMA_URL,
            json={
                "model": MODEL,
                "prompt": EXTRACTION_PROMPT + prompt_body,
                "stream": False,
                "format": "json",
                "options": {"num_predict": 512},
            },
            timeout=PER_CALL_TIMEOUT,
        )
        resp.raise_for_status()
        return {
            # `or ""` also covers a JSON null, which would otherwise break
            # the slicing of the raw response done by the caller.
            "response": resp.json().get("response") or "",
            "latency_s": round(time.perf_counter() - started, 2),
            "truncated": len(document_text) > MAX_DOC_CHARS,
        }
    except requests.exceptions.Timeout:
        return {"error": f"timeout after {PER_CALL_TIMEOUT}s", "latency_s": PER_CALL_TIMEOUT}
    except Exception as e:
        # Deliberately broad: any transport/HTTP/decoding failure is recorded
        # as a failed pass rather than crashing the run.
        return {"error": str(e), "latency_s": round(time.perf_counter() - started, 2)}


def parse_entities(raw_response):
    """Parse a model response into normalised entity lists.

    Strips an optional leading/trailing markdown code fence, decodes the JSON
    and, for each field in ENTITY_FIELDS, returns a sorted list of unique,
    stripped, lower-cased strings. A missing field is treated as empty.

    Returns None when the text is not valid JSON, when the top-level value is
    not an object, or when any entity field is not a list.
    """
    text = (raw_response or "").strip()
    text = re.sub(r"^```(?:json)?\s*", "", text)
    text = re.sub(r"\s*```$", "", text)
    try:
        data = json.loads(text)
    except json.JSONDecodeError:
        return None
    # Valid JSON need not be an object (e.g. "[]" or "42"); treat as unparsed
    # instead of failing on `.get`.
    if not isinstance(data, dict):
        return None

    out = {}
    for key in ENTITY_FIELDS:
        vals = data.get(key, [])
        if not isinstance(vals, list):
            return None
        normalised = (str(v).strip().lower() for v in vals if v)
        # Second filter drops values that were whitespace-only before strip().
        out[key] = sorted({v for v in normalised if v})
    return out


def entities_match(a, b):
    """Return True when both parsed results exist and agree on every entity field."""
    if a is None or b is None:
        return False
    return all(a[k] == b[k] for k in ENTITY_FIELDS)


def fetch_distinct_sources(pg_conn, n):
    """Fetch up to `n` (source, document) pairs, one per distinct source.

    Chunks belonging to the same source are concatenated in id order. The
    LIMIT is applied in SQL *before* the length filter below, so fewer than
    `n` pairs are returned when some sources have 50 or fewer non-whitespace-
    trimmed characters.
    """
    with pg_conn.cursor() as cur:  # closes the cursor even if the query fails
        cur.execute(
            """
            SELECT source, string_agg(document, E'\n\n' ORDER BY id) AS doc
            FROM embeddings
            WHERE source IS NOT NULL
            GROUP BY source
            ORDER BY MIN(id)
            LIMIT %s
            """,
            (n,),
        )
        rows = cur.fetchall()
    return [(s, d) for s, d in rows if d and len(d.strip()) > 50]


def _run_passes(doc_text):
    """Run N_PASSES extraction calls on one document and return the per-pass records."""
    passes = []
    for pass_no in range(1, N_PASSES + 1):
        r = call_mistral(doc_text)
        if "error" in r:
            print(f" pass {pass_no}: {r['error']}", flush=True)
            passes.append({"error": r["error"], "parsed_ok": False, "latency_s": r["latency_s"]})
            continue
        entities = parse_entities(r["response"])
        passes.append({
            "raw": r["response"][:500],
            "entities": entities,
            "latency_s": r["latency_s"],
            "parsed_ok": entities is not None,
            "truncated_input": r.get("truncated", False),
        })
    return passes


def _score_passes(passes):
    """Return (all_parsed, consistent, per_field) for one document's passes.

    `per_field` maps each entity field to whether every pass produced the
    same list for it; it is None unless every pass parsed.
    """
    all_parsed = bool(passes) and all(p.get("parsed_ok") for p in passes)
    if not all_parsed:
        return False, False, None
    extracted = [p["entities"] for p in passes]
    first, rest = extracted[0], extracted[1:]
    # Comparing every pass to the first is equivalent to pairwise equality.
    consistent = all(entities_match(first, other) for other in rest)
    per_field = {k: all(other[k] == first[k] for other in rest) for k in ENTITY_FIELDS}
    return all_parsed, consistent, per_field


def _build_summary(docs, results, started_at, total_elapsed):
    """Aggregate per-document results into the JSON-serialisable summary dict."""
    parsed = [r for r in results if r["all_parsed"]]
    consistent = [r for r in parsed if r["consistent"]]
    field_rates_pct = {
        k: round(100 * sum(1 for r in parsed if r["per_field_consistency"][k]) / len(parsed), 1)
        if parsed else 0.0
        for k in ENTITY_FIELDS
    }
    return {
        "experiment": "003",
        "title": "Entity-Only Consistency Test",
        "started_at": started_at,
        "completed_at": datetime.now(timezone.utc).isoformat(),
        "model": MODEL,
        "n_passes": N_PASSES,
        "per_call_timeout_s": PER_CALL_TIMEOUT,
        "max_doc_chars": MAX_DOC_CHARS,
        "n_documents": len(docs),
        "n_all_parsed": len(parsed),
        "n_fully_consistent": len(consistent),
        "consistency_rate_pct": round(100 * len(consistent) / len(docs), 2) if docs else 0.0,
        "consistency_rate_among_parsed_pct": (
            round(100 * len(consistent) / len(parsed), 2) if parsed else 0.0
        ),
        "per_field_consistency_pct": field_rates_pct,
        "total_elapsed_s": total_elapsed,
        "exp_001_baseline_pct": EXP_001_BASELINE_PCT,
        "results": results,
    }


def main():
    """Load documents from Postgres, run the passes, write and print the summary.

    Exits with status 1 when the PG_DSN environment variable is not set.
    """
    pg_dsn = os.environ.get("PG_DSN")
    if not pg_dsn:
        print("ERROR: PG_DSN not set", file=sys.stderr)
        sys.exit(1)

    pg_conn = psycopg2.connect(pg_dsn)
    try:
        docs = fetch_distinct_sources(pg_conn, N_DOCS)
    finally:
        # Release the connection even when the query raises.
        pg_conn.close()

    print(f"Loaded {len(docs)} distinct sources from pgvector")
    print(f"Model: {MODEL} | Passes per doc: {N_PASSES}")
    print(f"Per-call timeout: {PER_CALL_TIMEOUT}s | Max doc chars: {MAX_DOC_CHARS}")
    print(f"Calls planned: {len(docs) * N_PASSES}\n")

    results = []
    started_at = datetime.now(timezone.utc).isoformat()
    t_total = time.perf_counter()

    for i, (source, doc_text) in enumerate(docs, 1):
        size_marker = f"[{len(doc_text):>5}c]"
        print(f"[{i:02d}/{len(docs)}] {size_marker} {source[:55]}", flush=True)

        passes = _run_passes(doc_text)
        all_parsed, consistent, per_field = _score_passes(passes)

        latencies = [p.get("latency_s", 0) for p in passes]
        print(f" parsed={all_parsed} consistent={consistent} latencies={latencies}", flush=True)

        results.append({
            "source": source,
            "doc_chars": len(doc_text),
            "passes": passes,
            "all_parsed": all_parsed,
            "consistent": consistent,
            "per_field_consistency": per_field,
        })

    total_elapsed = round(time.perf_counter() - t_total, 1)
    summary = _build_summary(docs, results, started_at, total_elapsed)

    OUTPUT_FILE.parent.mkdir(parents=True, exist_ok=True)
    with open(OUTPUT_FILE, "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2)

    n_parsed = summary["n_all_parsed"]
    n_consistent = summary["n_fully_consistent"]
    delta = summary["consistency_rate_pct"] - EXP_001_BASELINE_PCT

    print()
    print("=" * 60)
    print(f"DONE — {len(docs)} docs in {total_elapsed}s")
    print(f"All {N_PASSES} passes parsed cleanly: {n_parsed}/{len(docs)}")
    print(f"Fully consistent (all 4 fields match): {n_consistent}/{len(docs)} ({summary['consistency_rate_pct']}%)")
    print(f"Among parsed only: {summary['consistency_rate_among_parsed_pct']}%")
    print(f"Per-field consistency: {summary['per_field_consistency_pct']}")
    print(f"Exp 001 baseline: {EXP_001_BASELINE_PCT:g}% | delta: {delta:+.2f} pts")
    print(f"Results: {OUTPUT_FILE}")


if __name__ == "__main__":
    main()