experiments: add consistency test and briefing generator results + scripts

This commit is contained in:
2026-04-28 02:47:41 +00:00
parent 9937abbe27
commit b6fe350ab2
6 changed files with 6985 additions and 0 deletions
+248
View File
@@ -0,0 +1,248 @@
#!/usr/bin/env python3
"""
BirdAI Cascaded Extraction — Consistency Test
"""
import json
import os
import urllib.request
import urllib.error
import psycopg2
import psycopg2.extras
import hashlib
import time
from datetime import datetime
from dotenv import load_dotenv
# Load credentials (PG_DSN etc.) from the project .env before reading them.
load_dotenv(os.path.expanduser("~/aaronai/.env"))
# Postgres DSN for the pgvector database holding the embeddings table.
PG_DSN = os.getenv("PG_DSN")
# JSON file that receives incremental and final test results.
RESULTS_FILE = os.path.expanduser("~/aaronai/consistency_test_results.json")
# Ollama model name to test, number of repeat passes per document,
# and how many random documents to sample.
MODEL = "mistral"
PASSES = 3
SAMPLE_SIZE = 50
# Local Ollama generate endpoint (non-streaming requests).
OLLAMA_URL = "http://localhost:11434/api/generate"
# Prompt prefix; the (truncated) document text is appended after "Text: ".
# The strict JSON-shape rules exist so repeated passes can be diffed.
EXTRACTION_PROMPT = """Extract named entities from this text. Return JSON only, no explanation, no prose.
Use exactly these fields (omit any field you are uncertain about, use empty list if none found):
{
"people": [],
"organizations": [],
"locations": [],
"dates": [],
"document_type": ""
}
Rules:
- Every value in people, organizations, locations, dates must be a plain string
- document_type must be a plain string
- No nested objects, no nested lists
- Only include entities you are certain about
- If uncertain about anything, omit it
Text: """
def get_sample_documents():
    """Fetch a random sample of mid-sized documents from pgvector.

    Returns up to SAMPLE_SIZE rows (as dict-like RealDictRow mappings with
    id, document, source, created_at), restricted to documents between
    100 and 3000 characters so extraction prompts stay a sane size.
    """
    query = """
        SELECT id, document, source, created_at
        FROM embeddings
        WHERE length(document) > 100
        AND length(document) < 3000
        ORDER BY random()
        LIMIT %s
    """
    conn = psycopg2.connect(PG_DSN)
    cur = conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
    cur.execute(query, (SAMPLE_SIZE,))
    rows = cur.fetchall()
    cur.close()
    conn.close()
    return rows
def run_extraction(text):
    """Run one entity-extraction pass on *text* via the local Ollama API.

    The document is truncated to 1500 characters, appended to
    EXTRACTION_PROMPT, and sent as a non-streaming generate request.

    Returns:
        (parsed, raw)  on success — parsed is the extracted dict, raw the
                       full model response text.
        (None, reason) on failure — reason is a short tagged string
                       ("TIMEOUT", "NO_JSON: ...", "URL_ERROR: ...", etc.)
                       that main() string-matches to classify the pass.
    """
    prompt = EXTRACTION_PROMPT + text[:1500]
    payload = json.dumps({
        "model": MODEL,
        "prompt": prompt,
        "stream": False
    }).encode()
    try:
        req = urllib.request.Request(
            OLLAMA_URL,
            data=payload,
            headers={"Content-Type": "application/json"}
        )
        with urllib.request.urlopen(req, timeout=180) as resp:
            result = json.loads(resp.read().decode())
        raw = result.get("response", "").strip()
        # Models often wrap JSON in prose; grab the outermost {...} span.
        start = raw.find("{")
        end = raw.rfind("}") + 1
        if start == -1 or end == 0:
            return None, f"NO_JSON: {raw[:100]}"
        json_str = raw[start:end]
        parsed = json.loads(json_str)
        if not isinstance(parsed, dict):
            return None, f"NOT_DICT: {json_str[:100]}"
        return parsed, raw
    except urllib.error.URLError as e:
        # BUGFIX: a connect-phase timeout surfaces as URLError(reason=timeout),
        # not a bare TimeoutError. Classify it as TIMEOUT so main()'s
        # '"TIMEOUT" in raw' check counts it with read-phase timeouts
        # instead of lumping it under FAILED.
        if isinstance(e.reason, TimeoutError):
            return None, "TIMEOUT"
        return None, f"URL_ERROR: {e}"
    except TimeoutError:
        # Read-phase socket timeout (socket.timeout is TimeoutError on 3.10+).
        return None, "TIMEOUT"
    except json.JSONDecodeError as e:
        return None, f"JSON_ERROR: {e}"
    except Exception as e:
        # Last-resort catch so one bad document never aborts the whole run.
        return None, f"ERROR: {type(e).__name__}: {e}"
def flatten_value(v):
    """Canonicalize a value into a comparable string form.

    Strings are lowercased and stripped; dicts become sorted-key JSON
    (lowercased); lists become a JSON array of their sorted, recursively
    flattened members; anything else goes through str() then lowercase/strip.
    """
    if isinstance(v, dict):
        return json.dumps(v, sort_keys=True).lower()
    if isinstance(v, list):
        return json.dumps(sorted(flatten_value(item) for item in v))
    if isinstance(v, str):
        return v.lower().strip()
    return str(v).lower().strip()
def normalize_extraction(extracted):
    """Normalize a parsed extraction dict so repeat passes can be compared.

    Only the five expected fields are kept. Missing list fields default to
    [] and a missing document_type to ""; list members are flattened and
    sorted, scalars are flattened. Returns None when *extracted* is None.
    """
    if extracted is None:
        return None
    out = {}
    for field in ("people", "organizations", "locations", "dates", "document_type"):
        fallback = "" if field == "document_type" else []
        value = extracted.get(field, fallback)
        if isinstance(value, list):
            out[field] = sorted(flatten_value(item) for item in value)
        else:
            out[field] = flatten_value(value)
    return out
def extractions_consistent(extractions):
    """Return True iff every pass parsed and all normalized results agree."""
    # Any failed (None) pass means we cannot call the document consistent.
    for e in extractions:
        if e is None:
            return False
    normalized = [normalize_extraction(e) for e in extractions]
    if any(n is None for n in normalized):
        return False
    # Keep the [0] lookup inside the generator: an empty input is
    # vacuously consistent rather than an IndexError.
    return all(n == normalized[0] for n in normalized[1:])
def content_hash(text):
    """Return a short (8 hex char) MD5 fingerprint of *text* for display/dedup.

    MD5 is used purely as a fast fingerprint here, not for security.
    """
    digest = hashlib.md5(text.encode("utf-8")).hexdigest()
    return digest[:8]
def main():
    """Run the consistency test and write results to RESULTS_FILE.

    Samples documents from pgvector, runs each through the model PASSES
    times, and classifies every document as CONSISTENT / INCONSISTENT /
    FAILED / TIMEOUT. Results are checkpointed to disk after every document
    and finalized with a summary block at the end.
    """
    print(f"\nBirdAI Consistency Test")
    print(f"Model: {MODEL} | Passes: {PASSES} | Sample: {SAMPLE_SIZE} docs")
    print(f"Started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"Results: {RESULTS_FILE}")
    print("-" * 60)
    docs = get_sample_documents()
    print(f"Loaded {len(docs)} documents from pgvector\n")
    results = {
        "meta": {
            "model": MODEL,
            "passes": PASSES,
            "sample_size": len(docs),
            "started": datetime.now().isoformat(),
            "completed": None
        },
        "documents": [],
        "summary": {}
    }
    consistent_count = 0
    failed_count = 0
    timeout_count = 0
    for i, doc in enumerate(docs):
        doc_id = doc["id"]
        content = doc["document"]
        source = doc.get("source", "unknown")
        chash = content_hash(content)
        print(f"[{i+1:02d}/{len(docs)}] {source[:50]:<50} hash:{chash}", end=" ", flush=True)
        passes = []
        pass_times = []
        raw_outputs = []
        # Run the identical document through the model PASSES times.
        for p in range(PASSES):
            t_start = time.time()
            extracted, raw = run_extraction(content)
            t_end = time.time()
            passes.append(extracted)
            pass_times.append(round(t_end - t_start, 1))
            raw_outputs.append(raw[:200] if raw else "")
        consistent = extractions_consistent(passes)
        any_timeout = any("TIMEOUT" in str(r) for r in raw_outputs)
        any_failed = any(p is None for p in passes)
        # Status precedence: timeout > failed > consistent > inconsistent.
        if any_timeout:
            timeout_count += 1
            status = "TIMEOUT"
        elif any_failed:
            failed_count += 1
            status = "FAILED"
        elif consistent:
            consistent_count += 1
            status = "CONSISTENT"
        else:
            status = "INCONSISTENT"
        print(f"{status} ({'/'.join(str(t) for t in pass_times)}s)")
        # A sample of the normalized extraction aids post-hoc inspection;
        # never let a malformed pass-0 result abort the run.
        try:
            sample_extraction = normalize_extraction(passes[0]) if passes[0] else None
        except Exception:
            sample_extraction = None
        results["documents"].append({
            "id": doc_id,
            "source": source,
            "content_hash": chash,
            "content_length": len(content),
            "status": status,
            "consistent": consistent,
            "pass_times_seconds": pass_times,
            "extraction_sample": sample_extraction,
            "raw_samples": raw_outputs
        })
        # Checkpoint after every document so a crash loses at most one doc.
        with open(RESULTS_FILE, "w") as f:
            json.dump(results, f, indent=2, default=str)
    total = len(docs)
    completed_at = datetime.now().isoformat()
    results["meta"]["completed"] = completed_at
    # BUGFIX: guard the rate math — an empty sample (DB down, empty table)
    # previously raised ZeroDivisionError here after the whole run.
    consistency_rate = round(consistent_count / total * 100, 1) if total else 0.0
    summary = {
        "total": total,
        "consistent": consistent_count,
        "inconsistent": total - consistent_count - failed_count - timeout_count,
        "failed": failed_count,
        "timeout": timeout_count,
        "consistency_rate": consistency_rate,
        "cascade_viable": total > 0 and consistent_count / total >= 0.5
    }
    results["summary"] = summary
    with open(RESULTS_FILE, "w") as f:
        json.dump(results, f, indent=2, default=str)
    print("\n" + "=" * 60)
    print(f"RESULTS")
    print(f"  Consistent: {consistent_count}/{total} ({summary['consistency_rate']}%)")
    print(f"  Inconsistent: {summary['inconsistent']}")
    print(f"  Failed/Timeout: {failed_count + timeout_count}")
    print(f"  Cascade viable: {'YES' if summary['cascade_viable'] else 'NO — reconsider architecture'}")
    print(f"  Completed: {completed_at}")
    print(f"  Full results: {RESULTS_FILE}")
    print("=" * 60)
# Script entry point: run the consistency test when executed directly.
if __name__ == "__main__":
    main()