experiments: add consistency test and briefing generator results + scripts
This commit is contained in:
@@ -0,0 +1,248 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
BirdAI Cascaded Extraction — Consistency Test
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import urllib.request
|
||||
import urllib.error
|
||||
import psycopg2
|
||||
import psycopg2.extras
|
||||
import hashlib
|
||||
import time
|
||||
from datetime import datetime
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load secrets/config (PG_DSN, etc.) from the project .env file.
load_dotenv(os.path.expanduser("~/aaronai/.env"))

# Postgres connection string for the pgvector "embeddings" table.
PG_DSN = os.getenv("PG_DSN")
# Where the JSON report is written (also rewritten as a running checkpoint).
RESULTS_FILE = os.path.expanduser("~/aaronai/consistency_test_results.json")
# Ollama model name under test.
MODEL = "mistral"
# Extraction passes per document; all passes must agree to count as consistent.
PASSES = 3
# Number of random documents sampled from the embeddings table.
SAMPLE_SIZE = 50
# Local Ollama /api/generate endpoint.
OLLAMA_URL = "http://localhost:11434/api/generate"

# Prompt sent verbatim (with the document text appended) on every pass.
EXTRACTION_PROMPT = """Extract named entities from this text. Return JSON only, no explanation, no prose.
Use exactly these fields (omit any field you are uncertain about, use empty list if none found):
{
"people": [],
"organizations": [],
"locations": [],
"dates": [],
"document_type": ""
}
Rules:
- Every value in people, organizations, locations, dates must be a plain string
- document_type must be a plain string
- No nested objects, no nested lists
- Only include entities you are certain about
- If uncertain about anything, omit it
Text: """
|
||||
|
||||
|
||||
def get_sample_documents():
    """Fetch a random sample of mid-sized documents from the pgvector store.

    Returns a list of dict-like rows (RealDictCursor) with keys:
    id, document, source, created_at.  Only documents between 100 and
    3000 characters are sampled so each pass fits the prompt budget.
    """
    conn = psycopg2.connect(PG_DSN)
    try:
        # Fix: previously the cursor/connection leaked if execute() or
        # fetchall() raised; close them deterministically.
        with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
            cur.execute("""
                SELECT id, document, source, created_at
                FROM embeddings
                WHERE length(document) > 100
                AND length(document) < 3000
                ORDER BY random()
                LIMIT %s
            """, (SAMPLE_SIZE,))
            return cur.fetchall()
    finally:
        conn.close()
|
||||
|
||||
|
||||
def run_extraction(text):
    """Run one entity-extraction pass against the local Ollama server.

    Sends EXTRACTION_PROMPT plus the first 1500 characters of *text*
    and attempts to parse a JSON object out of the model response.

    Returns:
        (parsed_dict, raw_response) on success, or (None, reason) where
        reason is a short "CODE: detail" diagnostic.  Timeout failures
        always contain the substring "TIMEOUT" — main() relies on this
        to count timeouts.
    """
    prompt = EXTRACTION_PROMPT + text[:1500]
    payload = json.dumps({
        "model": MODEL,
        "prompt": prompt,
        "stream": False
    }).encode()
    try:
        req = urllib.request.Request(
            OLLAMA_URL,
            data=payload,
            headers={"Content-Type": "application/json"}
        )
        with urllib.request.urlopen(req, timeout=180) as resp:
            result = json.loads(resp.read().decode())
        raw = result.get("response", "").strip()
        # The model may wrap the JSON in prose; take the outermost {...}.
        start = raw.find("{")
        end = raw.rfind("}") + 1
        # Fix: also reject start >= end (a "}" appearing before any "{"),
        # which previously produced an empty/garbled slice and surfaced
        # as a misleading JSON_ERROR instead of NO_JSON.
        if start == -1 or end <= start:
            return None, f"NO_JSON: {raw[:100]}"
        json_str = raw[start:end]
        parsed = json.loads(json_str)
        if not isinstance(parsed, dict):
            return None, f"NOT_DICT: {json_str[:100]}"
        return parsed, raw
    except urllib.error.URLError as e:
        # Fix: socket timeouts are often surfaced wrapped in URLError;
        # report them as TIMEOUT so main()'s timeout accounting is accurate.
        if isinstance(e.reason, TimeoutError):
            return None, "TIMEOUT"
        return None, f"URL_ERROR: {e}"
    except TimeoutError:
        return None, "TIMEOUT"
    except json.JSONDecodeError as e:
        return None, f"JSON_ERROR: {e}"
    except Exception as e:
        return None, f"ERROR: {type(e).__name__}: {e}"
|
||||
|
||||
|
||||
def flatten_value(v):
    """Reduce an extracted value to a canonical lowercase string form.

    Strings are lowercased and stripped; dicts serialize to
    deterministic JSON (sorted keys, lowercased); lists are flattened
    recursively, sorted, and serialized; anything else is stringified
    then lowercased.  Canonical forms let passes be compared for
    semantic equality regardless of ordering or casing.
    """
    if isinstance(v, str):
        return v.lower().strip()
    if isinstance(v, dict):
        return json.dumps(v, sort_keys=True).lower()
    if isinstance(v, list):
        flattened = [flatten_value(item) for item in v]
        return json.dumps(sorted(flattened))
    return str(v).lower().strip()
|
||||
|
||||
|
||||
def normalize_extraction(extracted):
    """Normalize an extraction dict for order-insensitive comparison.

    Missing fields default to [] ("" for document_type).  List fields
    are flattened and sorted; scalar fields are flattened in place.
    Returns None when *extracted* is None so failed passes propagate.
    """
    if extracted is None:
        return None
    result = {}
    for field in ("people", "organizations", "locations", "dates", "document_type"):
        default = "" if field == "document_type" else []
        value = extracted.get(field, default)
        if isinstance(value, list):
            result[field] = sorted(flatten_value(item) for item in value)
        else:
            result[field] = flatten_value(value)
    return result
|
||||
|
||||
|
||||
def extractions_consistent(extractions):
    """Return True when every pass produced the same normalized extraction.

    Any None entry (a failed pass), before or after normalization,
    makes the set inconsistent.
    """
    for item in extractions:
        if item is None:
            return False
    normalized = [normalize_extraction(item) for item in extractions]
    for item in normalized:
        if item is None:
            return False
    # Compare everything against the first pass; note the lazy reference
    # keeps an empty input trivially consistent (vacuous truth).
    return all(item == normalized[0] for item in normalized[1:])
|
||||
|
||||
|
||||
def content_hash(text):
    """Return a short (8 hex chars) MD5 fingerprint of *text*.

    Used only for log display / eyeballing duplicates — not security.
    """
    digest = hashlib.md5(text.encode())
    return digest.hexdigest()[:8]
|
||||
|
||||
|
||||
def main():
    """Run the consistency test end to end.

    Samples documents from pgvector, runs PASSES extraction passes per
    document, classifies each as CONSISTENT / INCONSISTENT / FAILED /
    TIMEOUT, and writes a JSON report to RESULTS_FILE (rewritten after
    every document as a crash checkpoint, then once more with the final
    summary).
    """
    print(f"\nBirdAI Consistency Test")
    print(f"Model: {MODEL} | Passes: {PASSES} | Sample: {SAMPLE_SIZE} docs")
    print(f"Started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"Results: {RESULTS_FILE}")
    print("-" * 60)

    docs = get_sample_documents()
    print(f"Loaded {len(docs)} documents from pgvector\n")

    results = {
        "meta": {
            "model": MODEL,
            "passes": PASSES,
            "sample_size": len(docs),
            "started": datetime.now().isoformat(),
            "completed": None
        },
        "documents": [],
        "summary": {}
    }

    consistent_count = 0
    failed_count = 0
    timeout_count = 0

    for i, doc in enumerate(docs):
        doc_id = doc["id"]
        content = doc["document"]
        source = doc.get("source", "unknown")
        chash = content_hash(content)

        print(f"[{i+1:02d}/{len(docs)}] {source[:50]:<50} hash:{chash}", end=" ", flush=True)

        passes = []
        pass_times = []
        raw_outputs = []

        for p in range(PASSES):
            t_start = time.time()
            extracted, raw = run_extraction(content)
            t_end = time.time()
            passes.append(extracted)
            pass_times.append(round(t_end - t_start, 1))
            raw_outputs.append(raw[:200] if raw else "")

        consistent = extractions_consistent(passes)
        # run_extraction encodes timeouts as a "TIMEOUT" marker in raw.
        any_timeout = any("TIMEOUT" in str(r) for r in raw_outputs)
        any_failed = any(p is None for p in passes)

        # Classification priority: timeout > hard failure > consistency.
        if any_timeout:
            timeout_count += 1
            status = "TIMEOUT"
        elif any_failed:
            failed_count += 1
            status = "FAILED"
        elif consistent:
            consistent_count += 1
            status = "CONSISTENT"
        else:
            status = "INCONSISTENT"

        print(f"→ {status} ({'/'.join(str(t) for t in pass_times)}s)")

        # Keep one normalized sample for the report; never let a malformed
        # extraction abort the whole run.
        try:
            sample_extraction = normalize_extraction(passes[0]) if passes[0] else None
        except Exception:
            sample_extraction = None

        results["documents"].append({
            "id": doc_id,
            "source": source,
            "content_hash": chash,
            "content_length": len(content),
            "status": status,
            "consistent": consistent,
            "pass_times_seconds": pass_times,
            "extraction_sample": sample_extraction,
            "raw_samples": raw_outputs
        })

        # Checkpoint after every document so a crash loses at most one doc.
        with open(RESULTS_FILE, "w") as f:
            json.dump(results, f, indent=2, default=str)

    total = len(docs)
    completed_at = datetime.now().isoformat()
    results["meta"]["completed"] = completed_at

    # Fix: an empty sample previously raised ZeroDivisionError here.
    consistency_rate = round(consistent_count / total * 100, 1) if total else 0.0
    summary = {
        "total": total,
        "consistent": consistent_count,
        "inconsistent": total - consistent_count - failed_count - timeout_count,
        "failed": failed_count,
        "timeout": timeout_count,
        "consistency_rate": consistency_rate,
        "cascade_viable": total > 0 and consistent_count / total >= 0.5
    }
    results["summary"] = summary

    with open(RESULTS_FILE, "w") as f:
        json.dump(results, f, indent=2, default=str)

    print("\n" + "=" * 60)
    print(f"RESULTS")
    print(f" Consistent: {consistent_count}/{total} ({summary['consistency_rate']}%)")
    print(f" Inconsistent: {summary['inconsistent']}")
    print(f" Failed/Timeout: {failed_count + timeout_count}")
    print(f" Cascade viable: {'YES' if summary['cascade_viable'] else 'NO — reconsider architecture'}")
    print(f" Completed: {completed_at}")
    print(f" Full results: {RESULTS_FILE}")
    print("=" * 60)
|
||||
|
||||
|
||||
# Script entry point: run the consistency test when executed directly.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user