add experiment scripts and results; watcher.py latest changes
This commit is contained in:
@@ -0,0 +1,230 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Experiment 003 — Entity-Only Consistency Test
|
||||
|
||||
Three Mistral passes per document, measure consistency on entity fields only
|
||||
(people, organizations, locations, dates). Excludes document_type label.
|
||||
DISTINCT ON (source) sampling — fixes Exp 001 chunk-replacement flaw.
|
||||
|
||||
Outputs: ~/aaronai/experiments/consistency_test_v2_results.json
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
import psycopg2
|
||||
import requests
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv(Path.home() / "aaronai" / ".env")
|
||||
|
||||
OUTPUT_FILE = Path.home() / "aaronai" / "experiments" / "consistency_test_v2_results.json"
|
||||
OLLAMA_URL = "http://localhost:11434/api/generate"
|
||||
MODEL = "mistral"
|
||||
N_PASSES = 3
|
||||
N_DOCS = 50
|
||||
PER_CALL_TIMEOUT = 60 # seconds — fail fast, don't wedge
|
||||
MAX_DOC_CHARS = 8000 # cap document length sent to Mistral
|
||||
|
||||
EXTRACTION_PROMPT = """Extract entities from the document below. Return ONLY valid JSON with this exact schema:
|
||||
{
|
||||
"people": [string],
|
||||
"organizations": [string],
|
||||
"locations": [string],
|
||||
"dates": [string]
|
||||
}
|
||||
Rules:
|
||||
- Only include entities you are CERTAIN about. If uncertain, omit.
|
||||
- No prose, no markdown fences, no commentary. JSON only.
|
||||
- Empty arrays are valid.
|
||||
|
||||
DOCUMENT:
|
||||
"""
|
||||
|
||||
|
||||
def call_mistral(document_text):
|
||||
truncated = document_text[:MAX_DOC_CHARS]
|
||||
t0 = time.time()
|
||||
try:
|
||||
resp = requests.post(
|
||||
OLLAMA_URL,
|
||||
json={
|
||||
"model": MODEL,
|
||||
"prompt": EXTRACTION_PROMPT + truncated,
|
||||
"stream": False,
|
||||
"format": "json",
|
||||
"options": {"num_predict": 512},
|
||||
},
|
||||
timeout=PER_CALL_TIMEOUT,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return {
|
||||
"response": resp.json().get("response", ""),
|
||||
"latency_s": round(time.time() - t0, 2),
|
||||
"truncated": len(document_text) > MAX_DOC_CHARS,
|
||||
}
|
||||
except requests.exceptions.Timeout:
|
||||
return {"error": f"timeout after {PER_CALL_TIMEOUT}s", "latency_s": PER_CALL_TIMEOUT}
|
||||
except Exception as e:
|
||||
return {"error": str(e), "latency_s": round(time.time() - t0, 2)}
|
||||
|
||||
|
||||
def parse_entities(raw_response):
|
||||
text = (raw_response or "").strip()
|
||||
text = re.sub(r"^```(?:json)?\s*", "", text)
|
||||
text = re.sub(r"\s*```$", "", text)
|
||||
try:
|
||||
data = json.loads(text)
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
out = {}
|
||||
for key in ("people", "organizations", "locations", "dates"):
|
||||
vals = data.get(key, [])
|
||||
if not isinstance(vals, list):
|
||||
return None
|
||||
out[key] = sorted(set(str(v).strip().lower() for v in vals if v))
|
||||
return out
|
||||
|
||||
|
||||
def entities_match(a, b):
|
||||
if a is None or b is None:
|
||||
return False
|
||||
return all(a[k] == b[k] for k in ("people", "organizations", "locations", "dates"))
|
||||
|
||||
|
||||
def fetch_distinct_sources(pg_conn, n):
|
||||
cur = pg_conn.cursor()
|
||||
cur.execute("""
|
||||
SELECT source, string_agg(document, E'\n\n' ORDER BY id) AS doc
|
||||
FROM embeddings
|
||||
WHERE source IS NOT NULL
|
||||
GROUP BY source
|
||||
ORDER BY MIN(id)
|
||||
LIMIT %s
|
||||
""", (n,))
|
||||
rows = cur.fetchall()
|
||||
cur.close()
|
||||
return [(s, d) for s, d in rows if d and len(d.strip()) > 50]
|
||||
|
||||
|
||||
def main():
|
||||
pg_dsn = os.environ.get("PG_DSN")
|
||||
if not pg_dsn:
|
||||
print("ERROR: PG_DSN not set", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
pg_conn = psycopg2.connect(pg_dsn)
|
||||
docs = fetch_distinct_sources(pg_conn, N_DOCS)
|
||||
pg_conn.close()
|
||||
|
||||
print(f"Loaded {len(docs)} distinct sources from pgvector")
|
||||
print(f"Model: {MODEL} | Passes per doc: {N_PASSES}")
|
||||
print(f"Per-call timeout: {PER_CALL_TIMEOUT}s | Max doc chars: {MAX_DOC_CHARS}")
|
||||
print(f"Calls planned: {len(docs) * N_PASSES}\n")
|
||||
|
||||
results = []
|
||||
started_at = datetime.now(timezone.utc).isoformat()
|
||||
t_total = time.time()
|
||||
|
||||
for i, (source, doc_text) in enumerate(docs, 1):
|
||||
size_marker = f"[{len(doc_text):>5}c]"
|
||||
print(f"[{i:02d}/{len(docs)}] {size_marker} {source[:55]}", flush=True)
|
||||
passes = []
|
||||
for p in range(N_PASSES):
|
||||
r = call_mistral(doc_text)
|
||||
if "error" in r:
|
||||
print(f" pass {p+1}: {r['error']}", flush=True)
|
||||
passes.append({"error": r["error"], "parsed_ok": False, "latency_s": r["latency_s"]})
|
||||
else:
|
||||
entities = parse_entities(r["response"])
|
||||
passes.append({
|
||||
"raw": r["response"][:500],
|
||||
"entities": entities,
|
||||
"latency_s": r["latency_s"],
|
||||
"parsed_ok": entities is not None,
|
||||
"truncated_input": r.get("truncated", False),
|
||||
})
|
||||
|
||||
all_parsed = all(p.get("parsed_ok") for p in passes)
|
||||
if all_parsed:
|
||||
e1, e2, e3 = passes[0]["entities"], passes[1]["entities"], passes[2]["entities"]
|
||||
consistent = entities_match(e1, e2) and entities_match(e2, e3)
|
||||
per_field = {
|
||||
k: (e1[k] == e2[k] == e3[k])
|
||||
for k in ("people", "organizations", "locations", "dates")
|
||||
}
|
||||
else:
|
||||
consistent = False
|
||||
per_field = None
|
||||
|
||||
latencies = [p.get("latency_s", 0) for p in passes]
|
||||
print(f" parsed={all_parsed} consistent={consistent} latencies={latencies}", flush=True)
|
||||
|
||||
results.append({
|
||||
"source": source,
|
||||
"doc_chars": len(doc_text),
|
||||
"passes": passes,
|
||||
"all_parsed": all_parsed,
|
||||
"consistent": consistent,
|
||||
"per_field_consistency": per_field,
|
||||
})
|
||||
|
||||
total_elapsed = round(time.time() - t_total, 1)
|
||||
|
||||
parsed = [r for r in results if r["all_parsed"]]
|
||||
consistent = [r for r in parsed if r["consistent"]]
|
||||
|
||||
field_rates = {k: 0 for k in ("people", "organizations", "locations", "dates")}
|
||||
for r in parsed:
|
||||
for k, v in (r["per_field_consistency"] or {}).items():
|
||||
if v:
|
||||
field_rates[k] += 1
|
||||
field_rates_pct = {
|
||||
k: round(100 * v / len(parsed), 1) if parsed else 0.0
|
||||
for k, v in field_rates.items()
|
||||
}
|
||||
|
||||
summary = {
|
||||
"experiment": "003",
|
||||
"title": "Entity-Only Consistency Test",
|
||||
"started_at": started_at,
|
||||
"completed_at": datetime.now(timezone.utc).isoformat(),
|
||||
"model": MODEL,
|
||||
"n_passes": N_PASSES,
|
||||
"per_call_timeout_s": PER_CALL_TIMEOUT,
|
||||
"max_doc_chars": MAX_DOC_CHARS,
|
||||
"n_documents": len(docs),
|
||||
"n_all_parsed": len(parsed),
|
||||
"n_fully_consistent": len(consistent),
|
||||
"consistency_rate_pct": round(100 * len(consistent) / len(docs), 2) if docs else 0.0,
|
||||
"consistency_rate_among_parsed_pct": (
|
||||
round(100 * len(consistent) / len(parsed), 2) if parsed else 0.0
|
||||
),
|
||||
"per_field_consistency_pct": field_rates_pct,
|
||||
"total_elapsed_s": total_elapsed,
|
||||
"exp_001_baseline_pct": 18.0,
|
||||
"results": results,
|
||||
}
|
||||
|
||||
OUTPUT_FILE.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(OUTPUT_FILE, "w") as f:
|
||||
json.dump(summary, f, indent=2)
|
||||
|
||||
print()
|
||||
print("=" * 60)
|
||||
print(f"DONE — {len(docs)} docs in {total_elapsed}s")
|
||||
print(f"All 3 passes parsed cleanly: {len(parsed)}/{len(docs)}")
|
||||
print(f"Fully consistent (all 4 fields match): {len(consistent)}/{len(docs)} ({summary['consistency_rate_pct']}%)")
|
||||
print(f"Among parsed only: {summary['consistency_rate_among_parsed_pct']}%")
|
||||
print(f"Per-field consistency: {field_rates_pct}")
|
||||
print(f"Exp 001 baseline: 18% | delta: {summary['consistency_rate_pct'] - 18.0:+.2f} pts")
|
||||
print(f"Results: {OUTPUT_FILE}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user