#!/usr/bin/env python3
"""
BirdAI Cascaded Extraction — Consistency Test

Runs the same entity-extraction prompt PASSES times per document against a
local Ollama model and measures how often the normalized outputs agree.
"""

import hashlib
import json
import os
import time
import urllib.error
import urllib.request
from datetime import datetime

import psycopg2
import psycopg2.extras
from dotenv import load_dotenv

load_dotenv(os.path.expanduser("~/aaronai/.env"))

PG_DSN = os.getenv("PG_DSN")
RESULTS_FILE = os.path.expanduser("~/aaronai/consistency_test_results.json")
MODEL = "mistral"
PASSES = 3          # extraction runs per document
SAMPLE_SIZE = 50    # documents sampled from pgvector
OLLAMA_URL = "http://localhost:11434/api/generate"

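# The prompt below is deliberately strict: JSON only, flat strings, omit
# anything uncertain. Presumably looser instructions would invite prose or
# nested structures, either of which defeats the pass-to-pass comparison
# this test depends on.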
EXTRACTION_PROMPT = """Extract named entities from this text. Return JSON only, no explanation, no prose.
Use exactly these fields (omit any field you are uncertain about, use empty list if none found):
{
"people": [],
"organizations": [],
"locations": [],
"dates": [],
"document_type": ""
}
Rules:
- Every value in people, organizations, locations, dates must be a plain string
- document_type must be a plain string
- No nested objects, no nested lists
- Only include entities you are certain about
- If uncertain about anything, omit it
Text: """

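# Illustrative well-formed response (hypothetical values, not from a run):
# {"people": ["Jane Doe"], "organizations": ["Acme Corp"],
#  "locations": ["Berlin"], "dates": ["2024-01-15"], "document_type": "invoice"}
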
def get_sample_documents():
    """Pull a random sample of mid-sized documents from the embeddings table."""
    conn = psycopg2.connect(PG_DSN)
    cur = conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
    cur.execute("""
        SELECT id, document, source, created_at
        FROM embeddings
        WHERE length(document) > 100
          AND length(document) < 3000
        ORDER BY random()
        LIMIT %s
    """, (SAMPLE_SIZE,))
    docs = cur.fetchall()
    cur.close()
    conn.close()
    return docs

def run_extraction(text):
    """One extraction pass against Ollama.

    Returns (parsed_dict, raw_text) on success, (None, error_tag) on failure.
    """
    prompt = EXTRACTION_PROMPT + text[:1500]
    payload = json.dumps({
        "model": MODEL,
        "prompt": prompt,
        "stream": False
    }).encode()
    try:
        req = urllib.request.Request(
            OLLAMA_URL,
            data=payload,
            headers={"Content-Type": "application/json"}
        )
        with urllib.request.urlopen(req, timeout=180) as resp:
            result = json.loads(resp.read().decode())
        raw = result.get("response", "").strip()
        # The model may wrap its JSON in prose; take the outermost {...} span.
        start = raw.find("{")
        end = raw.rfind("}") + 1
        if start == -1 or end == 0:
            return None, f"NO_JSON: {raw[:100]}"
        json_str = raw[start:end]
        parsed = json.loads(json_str)
        if not isinstance(parsed, dict):
            return None, f"NOT_DICT: {json_str[:100]}"
        return parsed, raw
    except TimeoutError:
        return None, "TIMEOUT"
    except urllib.error.URLError as e:
        # A connect timeout surfaces as URLError wrapping a timeout rather
        # than as a bare TimeoutError; tag it the same way so the summary
        # counts it as a timeout, not a generic URL error.
        if isinstance(e.reason, TimeoutError):
            return None, "TIMEOUT"
        return None, f"URL_ERROR: {e}"
    except json.JSONDecodeError as e:
        return None, f"JSON_ERROR: {e}"
    except Exception as e:
        return None, f"ERROR: {type(e).__name__}: {e}"

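# Two passes only count as identical after lower-casing, stripping, and
# sorting every value, so trivial ordering or casing differences between
# runs do not register as inconsistency.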
def flatten_value(v):
    """Collapse any extracted value to a canonical lowercase string."""
    if isinstance(v, str):
        return v.lower().strip()
    elif isinstance(v, dict):
        return json.dumps(v, sort_keys=True).lower()
    elif isinstance(v, list):
        return json.dumps(sorted([flatten_value(i) for i in v]))
    else:
        return str(v).lower().strip()

def normalize_extraction(extracted):
    """Coerce an extraction dict into a canonical form for comparison."""
    if extracted is None:
        return None
    normalized = {}
    expected_fields = ["people", "organizations", "locations", "dates", "document_type"]
    for key in expected_fields:
        val = extracted.get(key, [] if key != "document_type" else "")
        if isinstance(val, list):
            normalized[key] = sorted([flatten_value(v) for v in val])
        else:
            normalized[key] = flatten_value(val)
    return normalized

def extractions_consistent(extractions):
    """True only if every pass parsed and all normalized outputs match."""
    if any(e is None for e in extractions):
        return False
    normalized = [normalize_extraction(e) for e in extractions]
    if any(n is None for n in normalized):
        return False
    return all(n == normalized[0] for n in normalized[1:])

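# Sketch with hypothetical values: casing and list order are ignored, so
#   extractions_consistent([{"people": ["Ada", "Bob"]}, {"people": ["bob", "ada "]}])
# is True, while a missing or extra entity in any pass makes it False.
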
def content_hash(text):
    """Short MD5 prefix used only to label documents in output, not for security."""
    return hashlib.md5(text.encode()).hexdigest()[:8]

def main():
    print("\nBirdAI Consistency Test")
    print(f"Model: {MODEL} | Passes: {PASSES} | Sample: {SAMPLE_SIZE} docs")
    print(f"Started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"Results: {RESULTS_FILE}")
    print("-" * 60)

    docs = get_sample_documents()
    print(f"Loaded {len(docs)} documents from pgvector\n")

    results = {
        "meta": {
            "model": MODEL,
            "passes": PASSES,
            "sample_size": len(docs),
            "started": datetime.now().isoformat(),
            "completed": None
        },
        "documents": [],
        "summary": {}
    }

    consistent_count = 0
    failed_count = 0
    timeout_count = 0

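    # Each document gets PASSES independent runs; it only counts as consistent
    # if every pass parses and all normalized outputs match exactly.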
    for i, doc in enumerate(docs):
        doc_id = doc["id"]
        content = doc["document"]
        # `or` rather than a .get() default, so a NULL source column
        # also falls back to "unknown" instead of breaking the slice below.
        source = doc.get("source") or "unknown"
        chash = content_hash(content)

        print(f"[{i+1:02d}/{len(docs)}] {source[:50]:<50} hash:{chash}", end=" ", flush=True)

        passes = []
        pass_times = []
        raw_outputs = []

        for _ in range(PASSES):
            t_start = time.time()
            extracted, raw = run_extraction(content)
            t_end = time.time()
            passes.append(extracted)
            pass_times.append(round(t_end - t_start, 1))
            raw_outputs.append(raw[:200] if raw else "")

        consistent = extractions_consistent(passes)
        any_timeout = any("TIMEOUT" in str(r) for r in raw_outputs)
        any_failed = any(p is None for p in passes)

        if any_timeout:
            timeout_count += 1
            status = "TIMEOUT"
        elif any_failed:
            failed_count += 1
            status = "FAILED"
        elif consistent:
            consistent_count += 1
            status = "CONSISTENT"
        else:
            status = "INCONSISTENT"

        print(f"→ {status} ({'/'.join(str(t) for t in pass_times)}s)")

        try:
            sample_extraction = normalize_extraction(passes[0]) if passes[0] else None
        except Exception:
            sample_extraction = None

        results["documents"].append({
            "id": doc_id,
            "source": source,
            "content_hash": chash,
            "content_length": len(content),
            "status": status,
            "consistent": consistent,
            "pass_times_seconds": pass_times,
            "extraction_sample": sample_extraction,
            "raw_samples": raw_outputs
        })

        # Checkpoint after every document so a crash loses at most one doc.
        with open(RESULTS_FILE, "w") as f:
            json.dump(results, f, indent=2, default=str)

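    # cascade_viable applies a 0.5 threshold: if fewer than half the sampled
    # documents extract identically across passes, this model's output is
    # judged too noisy to cascade ("reconsider architecture").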
    total = len(docs)
    completed_at = datetime.now().isoformat()
    results["meta"]["completed"] = completed_at

    summary = {
        "total": total,
        "consistent": consistent_count,
        "inconsistent": total - consistent_count - failed_count - timeout_count,
        "failed": failed_count,
        "timeout": timeout_count,
        # Guard the empty-sample case so the report never divides by zero.
        "consistency_rate": round(consistent_count / total * 100, 1) if total else 0.0,
        "cascade_viable": total > 0 and consistent_count / total >= 0.5
    }
    results["summary"] = summary

    with open(RESULTS_FILE, "w") as f:
        json.dump(results, f, indent=2, default=str)

    print("\n" + "=" * 60)
    print("RESULTS")
    print(f"  Consistent: {consistent_count}/{total} ({summary['consistency_rate']}%)")
    print(f"  Inconsistent: {summary['inconsistent']}")
    print(f"  Failed/Timeout: {failed_count + timeout_count}")
    print(f"  Cascade viable: {'YES' if summary['cascade_viable'] else 'NO — reconsider architecture'}")
    print(f"  Completed: {completed_at}")
    print(f"  Full results: {RESULTS_FILE}")
    print("=" * 60)

if __name__ == "__main__":
    main()