#!/usr/bin/env python3
"""
Experiment 005 — Actual API Token Measurement

Measures input token reduction from prepending v2 briefing vs raw document
on Claude Haiku, validating the 42.0% modeled estimate from Experiment 002b.

Outputs: ~/aaronai/experiments/token_measurement_results.json
"""
|
|
|
|
import json
|
|
import os
|
|
import statistics
|
|
import sys
|
|
import time
|
|
from datetime import datetime, timezone
|
|
from pathlib import Path
|
|
|
|
import anthropic
|
|
import psycopg2
|
|
from dotenv import load_dotenv
|
|
|
|
# All experiment files live under ~/aaronai.
_AARONAI_DIR = Path.home() / "aaronai"

# Pull settings from ~/aaronai/.env into the process environment at import time.
load_dotenv(_AARONAI_DIR / ".env")

# Input: the v2 briefing results; output: this experiment's summary JSON.
INPUT_FILE = _AARONAI_DIR / "briefing_test_v2_results.json"
OUTPUT_FILE = _AARONAI_DIR / "experiments" / "token_measurement_results.json"

# Model identifier and response-token cap used for every request.
MODEL = "claude-haiku-4-5-20251001"
MAX_TOKENS = 1024

# Instruction text placed at the top of both the raw and the briefed message.
EXTRACTION_PROMPT = "\n".join(
    (
        "Extract entities and their relationships from the document below. "
        "Return ONLY valid JSON with this schema:",
        "{",
        ' "people": [string],',
        ' "organizations": [string],',
        ' "locations": [string],',
        ' "dates": [string],',
        ' "relationships": [{"subject": string, "predicate": string, "object": string}]',
        "}",
        "No prose, no markdown fences, no commentary. JSON only.",
    )
)
|
|
|
|
|
|
def fetch_document_text(pg_conn, source):
    """Reconstruct the document by concatenating its chunks from pgvector.

    Selects the ``document`` column of every ``embeddings`` row whose
    ``source`` equals *source*, ordered by ``id``, and joins the values with
    a blank line between them.

    Args:
        pg_conn: Open database connection exposing ``cursor()``.
        source: Value matched against the ``source`` column (passed as a
            bound query parameter, never interpolated into the SQL).

    Returns:
        The joined text, or ``None`` when no row matches.
    """
    cur = pg_conn.cursor()
    try:
        cur.execute(
            "SELECT document FROM embeddings WHERE source = %s ORDER BY id",
            (source,),
        )
        rows = cur.fetchall()
    finally:
        # Release the cursor even when execute()/fetchall() raises.
        cur.close()

    if not rows:
        return None
    return "\n\n".join(row[0] for row in rows)
|
|
|
|
|
|
def build_raw_message(document_text):
    """Return the baseline message: the extraction prompt followed by the document."""
    return "{}\n\nDOCUMENT:\n{}".format(EXTRACTION_PROMPT, document_text)
|
|
|
|
|
|
def build_briefed_message(briefing, document_text):
    """Return the briefed message: prompt, JSON-rendered briefing, then the document.

    The briefing is serialized with ``json.dumps(..., indent=2)`` and placed
    between the extraction prompt and the document. The document is still
    included in full after the briefing.
    """
    sections = (
        EXTRACTION_PROMPT,
        "BRIEFING (pre-analysis from local model — use to orient):\n"
        + json.dumps(briefing, indent=2),
        f"DOCUMENT:\n{document_text}",
    )
    # Sections are separated by exactly one blank line.
    return "\n\n".join(sections)
|
|
|
|
|
|
def call_haiku(client, message_text):
    """Send one user message to ``MODEL`` and return its usage figures.

    Args:
        client: Client object exposing ``messages.create(...)``.
        message_text: Full text sent as the single user message.

    Returns:
        A dict with:
          - ``input_tokens`` / ``output_tokens``: taken from ``resp.usage``.
          - ``latency_s``: seconds spent in ``messages.create``, rounded to
            two decimals.
          - ``response_text``: ``.text`` of the first content block, or ``""``
            when the response has no content.
          - ``stop_reason``: copied from the response.

    Any exception raised by the client propagates to the caller.
    """
    # perf_counter() is monotonic, so the measured interval cannot be skewed
    # by wall-clock adjustments the way time.time() can.
    started = time.perf_counter()
    resp = client.messages.create(
        model=MODEL,
        max_tokens=MAX_TOKENS,
        messages=[{"role": "user", "content": message_text}],
    )
    # Stop the clock as soon as the call returns, before any post-processing.
    latency_s = round(time.perf_counter() - started, 2)

    return {
        "input_tokens": resp.usage.input_tokens,
        "output_tokens": resp.usage.output_tokens,
        "latency_s": latency_s,
        "response_text": resp.content[0].text if resp.content else "",
        "stop_reason": resp.stop_reason,
    }
|
|
|
|
|
|
def ci_95(values, z=1.96):
    """Return ``(mean, half_width)`` of a normal-approximation interval.

    The half-width is ``z * stdev / sqrt(n)`` using the sample standard
    deviation. With the default ``z=1.96`` this is the usual two-sided 95%
    interval.

    Args:
        values: Any iterable of numbers; it is materialized once, so
            generators are accepted.
        z: Critical value multiplying the standard error. Defaults to 1.96.

    Returns:
        ``(mean, half_width)``. With fewer than two values the half-width is
        ``0.0``; with no values the mean is ``0.0`` as well.
    """
    data = list(values)
    n = len(data)
    if n < 2:
        # stdev() needs at least two points; report a zero-width interval.
        return (statistics.mean(data) if data else 0.0, 0.0)

    mean = statistics.mean(data)
    half = z * statistics.stdev(data) / (n ** 0.5)
    return (mean, half)
|
|
|
|
|
|
def _attempt_call(client, label, build_message, *build_args):
    """Run one measurement call, turning any failure into an error record.

    The message is built inside the ``try`` so a failure while building it is
    recorded the same way as a failed API call.
    """
    try:
        return call_haiku(client, build_message(*build_args))
    except Exception as exc:
        # Best-effort by design: one failed call must not abort the whole run.
        print(f" {label} FAILED: {exc}")
        return {"error": str(exc)}


def main():
    """Run the experiment end to end and write the summary JSON.

    Steps:
      1. Check that ``INPUT_FILE`` exists and that ``ANTHROPIC_API_KEY`` and
         ``PG_DSN`` are set; print an error and exit with status 1 otherwise.
      2. Load the input and keep documents whose ``status`` is ``"SUCCESS"``
         and that carry a non-empty ``briefing``.
      3. For each document, fetch its text and send one raw and one briefed
         request; documents with no stored text are recorded as skipped.
      4. Aggregate token counts, cost, and reduction statistics, write them
         to ``OUTPUT_FILE``, and print a short report.
    """
    # Modeled estimate the measured mean reduction is compared against.
    modeled_reduction_pct = 42.0
    # Prices in USD per million tokens (costs below divide by 1_000_000).
    # NOTE(review): hard-coded pricing assumption — confirm it is current.
    usd_per_mtok_in = 1.0
    usd_per_mtok_out = 5.0

    if not INPUT_FILE.exists():
        print(f"ERROR: {INPUT_FILE} not found", file=sys.stderr)
        sys.exit(1)

    api_key = os.environ.get("ANTHROPIC_API_KEY")
    if not api_key:
        print("ERROR: ANTHROPIC_API_KEY not set", file=sys.stderr)
        sys.exit(1)

    pg_dsn = os.environ.get("PG_DSN")
    if not pg_dsn:
        print("ERROR: PG_DSN not set", file=sys.stderr)
        sys.exit(1)

    # Load and filter the input before opening any connection, so a malformed
    # input file cannot leave a database connection behind.
    with open(INPUT_FILE, encoding="utf-8") as f:
        v2_data = json.load(f)

    docs_meta = [
        d
        for d in v2_data["documents"]
        if d.get("status") == "SUCCESS" and d.get("briefing")
    ]
    n_docs = len(docs_meta)

    client = anthropic.Anthropic(api_key=api_key)
    pg_conn = psycopg2.connect(pg_dsn)

    print(f"Loaded {n_docs} successful briefings from {INPUT_FILE.name}")
    print(f"Model: {MODEL}")
    print(f"Calls planned: up to {n_docs * 2}\n")

    results = []
    started_at = datetime.now(timezone.utc).isoformat()
    # Monotonic clock: elapsed time is unaffected by wall-clock adjustments.
    t_total = time.perf_counter()

    try:
        for i, doc in enumerate(docs_meta, 1):
            source = doc["source"]
            briefing = doc["briefing"]

            document_text = fetch_document_text(pg_conn, source)
            if not document_text:
                print(f"[{i:02d}/{n_docs}] {source[:60]} -- SKIP (not in pgvector)")
                results.append({"source": source, "skipped": "not_in_pgvector"})
                continue

            print(f"[{i:02d}/{n_docs}] {source[:60]}")

            raw_result = _attempt_call(
                client, "RAW", build_raw_message, document_text
            )
            briefed_result = _attempt_call(
                client, "BRIEFED", build_briefed_message, briefing, document_text
            )

            # A delta is only defined when both calls returned usage figures.
            delta = None
            if "input_tokens" in raw_result and "input_tokens" in briefed_result:
                raw_in = raw_result["input_tokens"]
                briefed_in = briefed_result["input_tokens"]
                raw_out = raw_result["output_tokens"]
                briefed_out = briefed_result["output_tokens"]
                # Positive input_red means the briefed request used fewer
                # input tokens than the raw one.
                input_red = (raw_in - briefed_in) / raw_in * 100 if raw_in else 0.0
                output_delta = (
                    (briefed_out - raw_out) / raw_out * 100 if raw_out else 0.0
                )
                delta = {
                    "input_reduction_pct": round(input_red, 2),
                    "output_delta_pct": round(output_delta, 2),
                    "raw_input_tokens": raw_in,
                    "briefed_input_tokens": briefed_in,
                    "raw_output_tokens": raw_out,
                    "briefed_output_tokens": briefed_out,
                }
                print(
                    f" in: {raw_in} -> {briefed_in} ({input_red:+.1f}%) | "
                    f"out: {raw_out} -> {briefed_out}"
                )

            results.append({
                "source": source,
                "raw": raw_result,
                "briefed": briefed_result,
                "delta": delta,
            })
    finally:
        # Close the connection even if the loop is interrupted by an error.
        pg_conn.close()

    total_elapsed = round(time.perf_counter() - t_total, 1)

    valid = [r for r in results if r.get("delta") is not None]
    skipped = [r for r in results if r.get("skipped")]
    reductions = [r["delta"]["input_reduction_pct"] for r in valid]
    output_deltas = [r["delta"]["output_delta_pct"] for r in valid]
    raw_in_total = sum(r["delta"]["raw_input_tokens"] for r in valid)
    briefed_in_total = sum(r["delta"]["briefed_input_tokens"] for r in valid)
    raw_out_total = sum(r["delta"]["raw_output_tokens"] for r in valid)
    briefed_out_total = sum(r["delta"]["briefed_output_tokens"] for r in valid)

    raw_cost = (
        raw_in_total * usd_per_mtok_in + raw_out_total * usd_per_mtok_out
    ) / 1_000_000
    briefed_cost = (
        briefed_in_total * usd_per_mtok_in + briefed_out_total * usd_per_mtok_out
    ) / 1_000_000

    mean_red, ci_half = ci_95(reductions)
    mean_out_delta, _ = ci_95(output_deltas)

    summary = {
        "experiment": "005",
        "title": "Actual API Token Measurement",
        "started_at": started_at,
        "completed_at": datetime.now(timezone.utc).isoformat(),
        "model": MODEL,
        "extraction_prompt": EXTRACTION_PROMPT,
        "n_documents_attempted": n_docs,
        "n_skipped_not_in_pgvector": len(skipped),
        "n_valid_pairs": len(valid),
        # Whatever is neither a valid pair nor skipped had a failed call.
        "n_failed": n_docs - len(valid) - len(skipped),
        "total_elapsed_s": total_elapsed,
        "input_token_reduction": {
            "mean_pct": round(mean_red, 2),
            "ci_95_half_width_pct": round(ci_half, 2),
            "median_pct": round(statistics.median(reductions), 2) if reductions else None,
            "min_pct": round(min(reductions), 2) if reductions else None,
            "max_pct": round(max(reductions), 2) if reductions else None,
            "stdev_pct": round(statistics.stdev(reductions), 2) if len(reductions) > 1 else 0.0,
        },
        "output_token_delta": {"mean_pct": round(mean_out_delta, 2)},
        "totals": {
            "raw_input_tokens": raw_in_total,
            "briefed_input_tokens": briefed_in_total,
            "raw_output_tokens": raw_out_total,
            "briefed_output_tokens": briefed_out_total,
            "raw_cost_usd": round(raw_cost, 4),
            "briefed_cost_usd": round(briefed_cost, 4),
            "savings_usd": round(raw_cost - briefed_cost, 4),
        },
        "comparison_to_v2_estimate": {
            "v2_modeled_reduction_pct": modeled_reduction_pct,
            "measured_mean_reduction_pct": round(mean_red, 2),
            "delta_pct_points": round(mean_red - modeled_reduction_pct, 2),
        },
        "results": results,
    }

    OUTPUT_FILE.parent.mkdir(parents=True, exist_ok=True)
    with open(OUTPUT_FILE, "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2)

    print()
    print("=" * 60)
    print(f"DONE — {len(valid)}/{n_docs} valid pairs in {total_elapsed}s")
    if skipped:
        print(f"Skipped (not in pgvector): {len(skipped)}")
    print(f"Mean input token reduction: {mean_red:.2f}% +/- {ci_half:.2f}% (95% CI)")
    print(
        f"V2 modeled estimate: {modeled_reduction_pct}% | "
        f"delta: {mean_red - modeled_reduction_pct:+.2f} pts"
    )
    print(f"Mean output token delta: {mean_out_delta:+.2f}%")
    print(f"Total cost: ${raw_cost + briefed_cost:.4f}")
    print(f"Results: {OUTPUT_FILE}")
|
|
|
|
|
|
# Run the experiment only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|