Files
aaronAI/scripts/cascade_test.py
T

486 lines
18 KiB
Python

#!/usr/bin/env python3
"""
Cascade Test — Nodes-vs-Edges Experiment
Tests whether splitting graph extraction into "local drafts entity candidates,
API verifies + draws edges" reduces total API cost vs single-shot full
extraction, while producing a comparable graph.
Two conditions per document:
A — Baseline: single Claude Haiku call, full extraction
B — Cascade: Mistral lists entity candidates, then Haiku does verify+edges
Both conditions:
- See the full document (parity-respecting)
- Use open entity type vocabulary (no fixed schema)
- Use natural-language predicates (no constrained relations)
- Same target output schema, same temperature
Sample: 20 docs from briefing_test_v2_results.json, stratified by char length.
Reports API cost only. Local Mistral time is recorded but not monetized
(ran on the VPS, no per-token API charge).
Outputs: ~/aaronai/experiments/cascade_test_results.json
"""
import json
import os
import re
import statistics
import sys
import time
from datetime import datetime, timezone
from pathlib import Path
import anthropic
import psycopg2
import requests
from dotenv import load_dotenv
# Every input/output of this experiment lives under ~/aaronai.
_AARONAI_DIR = Path.home() / "aaronai"

# Populate os.environ from the project .env; main() later reads
# ANTHROPIC_API_KEY and PG_DSN via os.environ.get.
load_dotenv(_AARONAI_DIR / ".env")

# Input: results of the v2 briefing run; main() samples from its "documents".
V2_FILE = _AARONAI_DIR / "briefing_test_v2_results.json"
# Output: summary + per-document results, written once at the end of main().
OUTPUT_FILE = _AARONAI_DIR / "experiments" / "cascade_test_results.json"

# API model settings, shared by Condition A and Condition B's API call.
HAIKU_MODEL = "claude-haiku-4-5-20251001"
HAIKU_MAX_TOKENS = 4096
HAIKU_TEMPERATURE = 0.0

# Local model served by Ollama (Condition B's candidate-drafting pass).
OLLAMA_URL = "http://localhost:11434/api/generate"
LOCAL_MODEL = "mistral"
LOCAL_TIMEOUT = 120  # passed as requests.post(timeout=...)

# Document text is cut to this many characters before either model sees it.
MAX_DOC_CHARS = 8000

# Verified pricing 2026-04-28 against Anthropic docs
# USD per million tokens (input / output); used for the API cost totals.
HAIKU_IN_PER_M = 1.0
HAIKU_OUT_PER_M = 5.0
# Condition A (baseline): a single Haiku call performs the full extraction.
# main() appends the document text directly after the trailing "DOCUMENT:" line.
CONDITION_A_PROMPT = """Extract a knowledge graph from the document below.
Return ONLY valid JSON with this exact schema:
{
"entities": [
{"name": string, "type": string}
],
"edges": [
{"subject": string, "predicate": string, "object": string}
]
}
Entity types: use whatever fits the entity. Do not constrain yourself to a fixed list.
Edge predicates: natural language phrases that capture the actual relationship the document states or implies.
Extract every entity and every relationship the document states or strongly implies. Both subject and object in every edge must appear in entities. JSON only, no commentary, no markdown fences.
DOCUMENT:
"""
# Condition B, step 1: prompt for the local Ollama model. call_local() appends
# the document text; parse_candidates() reads the "candidates" list back out.
LOCAL_PROMPT = """List every named entity that appears in the document below — every person, organization, place, project, document, material, technique, date, event, or other named thing.
Return ONLY valid JSON:
{
"candidates": [string]
}
Just names. No types, no relationships. JSON only.
DOCUMENT:
"""
# Condition B, step 2: Haiku prompt seeded with the local model's candidates.
# The "{local_draft}" placeholder is filled in main() with str.replace rather
# than str.format, because the literal JSON braces in the schema would break
# .format(). The document text is appended after the trailing "DOCUMENT:" line.
# NOTE(review): the instruction wording is shorter than CONDITION_A_PROMPT's
# (e.g. "use whatever fits." vs "...fits the entity. Do not constrain yourself
# to a fixed list."). The module docstring promises parity between conditions,
# so confirm this difference is intended.
CONDITION_B_API_PROMPT_WITH_DRAFT = """Extract a knowledge graph from the document below.
A local model has identified entity candidates that may help orient your reading. Treat the candidates as a hint, not as truth — verify each candidate appears in the document, ignore any that do not, and add any entities the candidates missed.
Return ONLY valid JSON with this exact schema:
{
"entities": [
{"name": string, "type": string}
],
"edges": [
{"subject": string, "predicate": string, "object": string}
]
}
Entity types: use whatever fits. Edge predicates: natural language phrases capturing the actual relationship. Both subject and object in every edge must appear in entities. Extract every entity and every relationship the document states or strongly implies. JSON only, no commentary, no markdown fences.
ENTITY CANDIDATES FROM LOCAL MODEL:
{local_draft}
DOCUMENT:
"""
def strip_json_fences(text):
    """Remove a surrounding Markdown code fence from a model response.

    A leading ``` or ```json fence and a trailing ``` fence are removed, along
    with surrounding whitespace. Falsy input (None or "") yields "".
    """
    if not text:
        return ""
    body = re.sub(r"^```(?:json)?\s*", "", text.strip())
    body = re.sub(r"\s*```$", "", body)
    return body.strip()
def fetch_document_text(pg_conn, source, max_chars=None):
    """Reassemble a document from its stored chunks and truncate it for prompting.

    Args:
        pg_conn: open DB-API connection (psycopg2 in this script).
        source: value matched against ``embeddings.source``.
        max_chars: truncation limit in characters; defaults to MAX_DOC_CHARS.

    Returns:
        (text, original_length). ``text`` is the chunk texts joined with blank
        lines and cut to ``max_chars`` characters. ``original_length`` is the
        length before truncation. Returns (None, 0) when no rows match.
    """
    if max_chars is None:
        max_chars = MAX_DOC_CHARS
    cur = pg_conn.cursor()
    try:
        # `source` is a bound parameter, never interpolated into the SQL.
        # Rows come back in `id` order (presumably chunk order; confirm).
        cur.execute(
            "SELECT document FROM embeddings WHERE source = %s ORDER BY id",
            (source,),
        )
        rows = cur.fetchall()
    finally:
        # Release the cursor even when execute/fetchall raises.
        cur.close()
    if not rows:
        return None, 0
    # Skip NULL chunk texts; str.join would raise TypeError on None.
    full = "\n\n".join(r[0] for r in rows if r[0] is not None)
    return full[:max_chars], len(full)
def call_haiku(client, prompt_text):
    """Send one user message to the Haiku model and record usage and timing.

    Args:
        client: anthropic.Anthropic client.
        prompt_text: the full prompt (instructions followed by the document).

    Returns:
        dict with input_tokens, output_tokens, latency_s (wall-clock seconds,
        rounded to 2 dp), response_text (text of the first content block, or ""
        when there is none) and stop_reason. API errors are not caught here.
    """
    start = time.time()
    message = client.messages.create(
        model=HAIKU_MODEL,
        max_tokens=HAIKU_MAX_TOKENS,
        temperature=HAIKU_TEMPERATURE,
        messages=[{"role": "user", "content": prompt_text}],
    )
    elapsed = round(time.time() - start, 2)
    usage = message.usage
    if message.content:
        text = message.content[0].text
    else:
        text = ""
    return {
        "input_tokens": usage.input_tokens,
        "output_tokens": usage.output_tokens,
        "latency_s": elapsed,
        "response_text": text,
        "stop_reason": message.stop_reason,
    }
def call_local(document_text):
    """Ask the local Ollama model for entity candidates (best effort).

    Returns ``{"response": <raw model text>, "latency_s": float}`` on success,
    or ``{"error": <message>, "latency_s": float}`` on any failure. Failures
    are reported in the dict rather than raised, so one bad document does not
    abort the run.
    """
    start = time.time()
    try:
        # Payload is built inside the try so that, e.g., a None document is
        # reported as an error dict instead of escaping as a TypeError.
        payload = {
            "model": LOCAL_MODEL,
            "prompt": LOCAL_PROMPT + document_text,
            "stream": False,
            "format": "json",
            "options": {"num_predict": 1024, "temperature": 0, "num_ctx": 8192},
        }
        http_resp = requests.post(OLLAMA_URL, json=payload, timeout=LOCAL_TIMEOUT)
        http_resp.raise_for_status()
        raw = http_resp.json().get("response", "")
    except Exception as exc:
        return {"error": str(exc), "latency_s": round(time.time() - start, 2)}
    return {"response": raw, "latency_s": round(time.time() - start, 2)}
def parse_graph(raw):
    """Count the entities and edges in a graph-extraction response.

    Returns (entity_count, edge_count) when the fence-stripped response parses
    as a JSON object whose "entities" and "edges" are both lists. Otherwise
    returns (None, None): empty text, invalid JSON, a non-object, or missing
    or non-list fields.
    """
    failure = (None, None)
    cleaned = strip_json_fences(raw)
    if not cleaned:
        return failure
    try:
        data = json.loads(cleaned)
    except json.JSONDecodeError:
        return failure
    if isinstance(data, dict):
        ents, edges = data.get("entities"), data.get("edges")
        if isinstance(ents, list) and isinstance(edges, list):
            return len(ents), len(edges)
    return failure
def parse_candidates(raw):
    """Extract entity-candidate names from the local model's response.

    Returns a list of stripped, non-empty candidate strings in their original
    order. Returns None when the fence-stripped response is empty, is not
    valid JSON, is not a JSON object, or lacks a "candidates" list.
    """
    cleaned = strip_json_fences(raw)
    if not cleaned:
        return None
    try:
        data = json.loads(cleaned)
    except json.JSONDecodeError:
        return None
    if not isinstance(data, dict):
        return None
    cands = data.get("candidates")
    if not isinstance(cands, list):
        return None
    # Check emptiness AFTER stripping: a whitespace-only entry like "  " is
    # truthy before .strip() and would otherwise become a blank "- " bullet in
    # the Condition B prompt.
    names = (str(c).strip() for c in cands if c)
    return [name for name in names if name]
def stratify(docs):
    """Pick 5 small / 10 medium / 5 large by character length, in file order.

    Buckets use each doc's "content_length": small < 1000 <= medium < 5000 <=
    large. Within each bucket the earliest docs in input order are kept.
    """
    small, medium, large = [], [], []
    for doc in docs:
        n = doc["content_length"]
        if n < 1000:
            small.append(doc)
        elif 1000 <= n < 5000:
            medium.append(doc)
        elif n >= 5000:
            large.append(doc)
    picked = small[:5]
    picked.extend(medium[:10])
    picked.extend(large[:5])
    return picked
def _condition_a_record(a, a_ents, a_edges):
    """Build the per-document "condition_a" record.

    Args:
        a: dict returned by call_haiku, or {"error": str} when the call failed.
        a_ents, a_edges: counts from parse_graph (None when unparseable).

    Returns:
        dict of token usage, latency, counts, stop reason, the response text
        truncated to 4000 characters, and the error message (None on success).
    """
    return {
        "input_tokens": a.get("input_tokens"),
        "output_tokens": a.get("output_tokens"),
        "latency_s": a.get("latency_s"),
        "entity_count": a_ents,
        "edge_count": a_edges,
        "stop_reason": a.get("stop_reason"),
        "response_text": a.get("response_text", "")[:4000],
        "error": a.get("error"),
    }


def _fmt_pct(value):
    """Format a percentage as e.g. "+3.2%", or "n/a" when it is None."""
    return "n/a" if value is None else f"{value:+.1f}%"


def _bucket_stats(rows):
    """Aggregate the valid A/B pairs of one size bucket.

    Returns None for an empty bucket. Delta percentages are None when the
    Condition A total is 0. Average edge counts skip rows whose graph did not
    parse, so the A and B averages may cover different subsets of rows.
    """
    if not rows:
        return None
    ai = sum(r["condition_a"]["input_tokens"] for r in rows)
    ao = sum(r["condition_a"]["output_tokens"] for r in rows)
    bi = sum(r["condition_b"]["api_input_tokens"] for r in rows)
    bo = sum(r["condition_b"]["api_output_tokens"] for r in rows)
    ae = [r["condition_a"]["edge_count"] for r in rows if r["condition_a"]["edge_count"] is not None]
    be = [r["condition_b"]["edge_count"] for r in rows if r["condition_b"]["edge_count"] is not None]
    return {
        "n": len(rows),
        "a_input_tokens": ai,
        "a_output_tokens": ao,
        "b_input_tokens": bi,
        "b_output_tokens": bo,
        "input_delta_pct": round((bi - ai) / ai * 100, 2) if ai else None,
        "output_delta_pct": round((bo - ao) / ao * 100, 2) if ao else None,
        "a_avg_edges": round(statistics.mean(ae), 1) if ae else None,
        "b_avg_edges": round(statistics.mean(be), 1) if be else None,
    }


def _run_document(client, pg_conn, i, n_docs, doc_meta):
    """Run Condition A and Condition B on one sampled document.

    Prints per-document progress and returns the document's result record.
    API and local-model failures are recorded in the record, not raised.
    Database errors from fetch_document_text propagate to the caller.
    """
    source = doc_meta["source"]
    doc_text, original_len = fetch_document_text(pg_conn, source)
    if not doc_text:
        print(f"[{i:02d}/{n_docs}] {source[:60]} — SKIP (not in pgvector)")
        return {"source": source, "skipped": "not_in_pgvector"}

    sent_len = len(doc_text)
    truncated = original_len > sent_len
    # Bucket by the length actually sent to the models (after truncation).
    size_bucket = "small" if sent_len < 1000 else "medium" if sent_len < 5000 else "large"
    trunc_marker = "*" if truncated else " "
    print(f"[{i:02d}/{n_docs}] [{size_bucket:6s}] [{sent_len:>5}c{trunc_marker}] {source[:55]}", flush=True)

    record = {
        "source": source,
        "size_bucket": size_bucket,
        "doc_chars_original": original_len,
        "doc_chars_sent": sent_len,
        "truncated": truncated,
    }

    # Condition A: single-shot full extraction.
    try:
        a = call_haiku(client, CONDITION_A_PROMPT + doc_text)
        a_ents, a_edges = parse_graph(a["response_text"])
        print(f" A: in={a['input_tokens']} out={a['output_tokens']} "
              f"ents={a_ents} edges={a_edges} stop={a['stop_reason']} t={a['latency_s']}s",
              flush=True)
    except Exception as e:
        print(f" A FAILED: {e}", flush=True)
        a = {"error": str(e)}
        a_ents = a_edges = None
    record["condition_a"] = _condition_a_record(a, a_ents, a_edges)

    # Condition B, step 1: local model drafts entity candidates.
    local_result = call_local(doc_text)
    if "error" in local_result:
        print(f" B local FAILED: {local_result['error']} — skipping doc", flush=True)
        record["condition_b"] = {
            "skipped": "local_model_failed",
            "local_error": local_result["error"],
            "local_latency_s": local_result.get("latency_s"),
        }
        return record

    local_raw = local_result["response"]
    local_candidates = parse_candidates(local_raw) or []
    print(f" B local: t={local_result['latency_s']}s candidates={len(local_candidates)}",
          flush=True)
    if not local_candidates:
        # No usable draft: skip the B API call (see log message) instead of
        # running Haiku without the cascade hint.
        print(" B local: empty draft — skipping API call to avoid asymmetric test", flush=True)
        record["condition_b"] = {
            "skipped": "local_draft_empty",
            "local_latency_s": local_result.get("latency_s"),
            "local_raw": local_raw[:1000],
        }
        return record

    # Condition B, step 2: Haiku verifies the candidates and draws edges.
    local_draft_str = "\n".join(f"- {c}" for c in local_candidates)
    b_prompt = CONDITION_B_API_PROMPT_WITH_DRAFT.replace("{local_draft}", local_draft_str) + doc_text
    try:
        b = call_haiku(client, b_prompt)
        b_ents, b_edges = parse_graph(b["response_text"])
        print(f" B api: in={b['input_tokens']} out={b['output_tokens']} "
              f"ents={b_ents} edges={b_edges} stop={b['stop_reason']} t={b['latency_s']}s",
              flush=True)
    except Exception as e:
        print(f" B api FAILED: {e}", flush=True)
        b = {"error": str(e)}
        b_ents = b_edges = None

    if "input_tokens" in a and "input_tokens" in b:
        in_pct = (b["input_tokens"] - a["input_tokens"]) / a["input_tokens"] * 100 if a["input_tokens"] else 0.0
        out_pct = (b["output_tokens"] - a["output_tokens"]) / a["output_tokens"] * 100 if a["output_tokens"] else 0.0
        edge_pct_str = "n/a"
        # Edge counts come from len(), so a truthy a_edges is also > 0.
        if a_edges and b_edges is not None:
            edge_pct_str = f"{(b_edges - a_edges) / a_edges * 100:+.1f}%"
        print(f" Δ input={in_pct:+.1f}% output={out_pct:+.1f}% edges={edge_pct_str}", flush=True)

    record["condition_b"] = {
        "local_latency_s": local_result.get("latency_s"),
        "local_candidates": local_candidates,
        "local_raw": local_raw[:1000],
        "api_input_tokens": b.get("input_tokens"),
        "api_output_tokens": b.get("output_tokens"),
        "api_latency_s": b.get("latency_s"),
        "entity_count": b_ents,
        "edge_count": b_edges,
        "stop_reason": b.get("stop_reason"),
        "response_text": b.get("response_text", "")[:4000],
        "error": b.get("error"),
    }
    return record


def main():
    """Run the cascade experiment end to end.

    Reads ANTHROPIC_API_KEY and PG_DSN from the environment. Exits with status 1
    if either is missing or V2_FILE does not exist. Samples documents from
    V2_FILE via stratify(), runs both conditions on each, writes the summary
    JSON to OUTPUT_FILE, and prints an API-cost report.
    """
    api_key = os.environ.get("ANTHROPIC_API_KEY")
    pg_dsn = os.environ.get("PG_DSN")
    if not api_key or not pg_dsn:
        print("ERROR: ANTHROPIC_API_KEY or PG_DSN not set", file=sys.stderr)
        sys.exit(1)
    if not V2_FILE.exists():
        print(f"ERROR: {V2_FILE} not found", file=sys.stderr)
        sys.exit(1)
    with open(V2_FILE, encoding="utf-8") as f:
        v2 = json.load(f)
    docs_meta = [d for d in v2["documents"] if d.get("status") == "SUCCESS"]
    sample = stratify(docs_meta)
    print(f"Sample: {len(sample)} docs (stratified by char length, file order)")
    for d in sample:
        print(f" [{d['content_length']:>6}c] {d['source'][:60]}")
    print(f"Haiku model: {HAIKU_MODEL} temp={HAIKU_TEMPERATURE} max_tokens={HAIKU_MAX_TOKENS}")
    print(f"Local model: {LOCAL_MODEL}")
    print()

    client = anthropic.Anthropic(api_key=api_key)
    pg_conn = psycopg2.connect(pg_dsn)
    results = []
    started_at = datetime.now(timezone.utc).isoformat()
    t_total = time.time()
    try:
        for i, doc_meta in enumerate(sample, 1):
            results.append(_run_document(client, pg_conn, i, len(sample), doc_meta))
    finally:
        # Close the connection even if a document raises or the run is
        # interrupted. Results are only written below, so they are lost then.
        pg_conn.close()
    total_elapsed = round(time.time() - t_total, 1)

    # Only documents where both API calls reported usage count toward costs.
    valid = [r for r in results
             if r.get("condition_a", {}).get("input_tokens") is not None
             and r.get("condition_b", {}).get("api_input_tokens") is not None]
    a_in = sum(r["condition_a"]["input_tokens"] for r in valid)
    a_out = sum(r["condition_a"]["output_tokens"] for r in valid)
    b_in = sum(r["condition_b"]["api_input_tokens"] for r in valid)
    b_out = sum(r["condition_b"]["api_output_tokens"] for r in valid)
    a_cost = (a_in * HAIKU_IN_PER_M + a_out * HAIKU_OUT_PER_M) / 1_000_000
    b_cost = (b_in * HAIKU_IN_PER_M + b_out * HAIKU_OUT_PER_M) / 1_000_000
    by_bucket = {
        bucket: _bucket_stats([r for r in valid if r["size_bucket"] == bucket])
        for bucket in ("small", "medium", "large")
    }

    summary = {
        "experiment": "cascade_test",
        "title": "Nodes-vs-Edges Cascade Experiment",
        "started_at": started_at,
        "completed_at": datetime.now(timezone.utc).isoformat(),
        "haiku_model": HAIKU_MODEL,
        "haiku_temperature": HAIKU_TEMPERATURE,
        "haiku_max_tokens": HAIKU_MAX_TOKENS,
        "local_model": LOCAL_MODEL,
        "max_doc_chars": MAX_DOC_CHARS,
        "n_documents": len(sample),
        "n_valid_pairs": len(valid),
        "n_skipped": len(sample) - len(valid),
        "total_elapsed_s": total_elapsed,
        "totals": {
            "a_input_tokens": a_in,
            "a_output_tokens": a_out,
            "b_input_tokens": b_in,
            "b_output_tokens": b_out,
            "a_cost_usd": round(a_cost, 4),
            "b_cost_usd": round(b_cost, 4),
            "cost_delta_usd": round(b_cost - a_cost, 4),
            "cost_delta_pct": round((b_cost - a_cost) / a_cost * 100, 2) if a_cost else None,
            "note": "API cost only — local Mistral runtime on VPS not monetized",
        },
        "by_size_bucket": by_bucket,
        "results": results,
    }
    OUTPUT_FILE.parent.mkdir(parents=True, exist_ok=True)
    with open(OUTPUT_FILE, "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2)

    print()
    print("=" * 60)
    print(f"DONE — {len(valid)}/{len(sample)} valid pairs in {total_elapsed}s")
    print(f"A total cost: ${a_cost:.4f} (in={a_in} out={a_out})")
    print(f"B total cost: ${b_cost:.4f} (in={b_in} out={b_out})")
    delta_pct = summary["totals"]["cost_delta_pct"]
    if delta_pct is not None:
        if delta_pct < 0:
            verdict = "B cheaper"
        elif delta_pct > 0:
            verdict = "B more expensive"
        else:
            # A zero delta was previously reported as "B more expensive".
            verdict = "no difference"
        print(f"Cost delta: {delta_pct:+.2f}% ({verdict})")
    print()
    print("By size bucket:")
    for bucket, stats in by_bucket.items():
        if stats:
            # _fmt_pct tolerates None deltas, which previously crashed the :+.1f format.
            print(f" {bucket:6s} (n={stats['n']}): "
                  f"in {_fmt_pct(stats['input_delta_pct'])} "
                  f"out {_fmt_pct(stats['output_delta_pct'])} "
                  f"edges A={stats['a_avg_edges']} B={stats['b_avg_edges']}")
    print()
    print("NOTE: API cost only. Local Mistral runtime is not monetized.")
    print(f"Results: {OUTPUT_FILE}")


if __name__ == "__main__":
    main()