#!/usr/bin/env python3
"""
Base-Class Enrichment Test — OOP Framing Experiment

Tests whether non-entity metadata from a local model (domain class,
structural signals, presence flags, length, summary) can take load off the
API without constraining what it extracts.

The local model does NOT draft entities. The API still does full extraction.
The local model produces metadata that orients the API's reading.

Conditions:
  A — Baseline: single Claude Haiku call, full extraction, no metadata
  B — Base-class: Mistral metadata + Haiku full extraction with metadata as frame

Critical test: B's edge count and predicate diversity must be ≥A's, or close.
If B produces fewer edges or less predicate diversity, metadata is acting as
constraint and the OOP framing is falsified.

Sample: 20 docs from briefing_test_v2_results.json:
  - 5 small (<1000 chars)
  - 10 medium (1000-5000 chars)
  - 5 large (5000-12000 chars, capped at 12K)

Outputs: ~/aaronai/experiments/base_class_test_results.json
"""

import json
import os
import re
import statistics
import sys
import time
from datetime import datetime, timezone
from pathlib import Path

import anthropic
import psycopg2
import requests
from dotenv import load_dotenv

load_dotenv(Path.home() / "aaronai" / ".env")

V2_FILE = Path.home() / "aaronai" / "briefing_test_v2_results.json"
OUTPUT_FILE = Path.home() / "aaronai" / "experiments" / "base_class_test_results.json"

HAIKU_MODEL = "claude-haiku-4-5-20251001"
HAIKU_MAX_TOKENS = 4096
HAIKU_TEMPERATURE = 0.0

OLLAMA_URL = "http://localhost:11434/api/generate"
LOCAL_MODEL = "mistral"
LOCAL_TIMEOUT = 180
# Context window requested from the local model; also echoed in the run banner.
LOCAL_NUM_CTX = 12288

MAX_DOC_CHARS = 12000

# Rates applied per 1,000,000 tokens when computing the *_cost_usd totals.
HAIKU_IN_PER_M = 1.0
HAIKU_OUT_PER_M = 5.0

# NOTE(review): the wording of the three prompts below is verbatim, but the
# line breaks inside the literals were reconstructed from a
# whitespace-collapsed copy of this file — confirm against the canonical
# prompts before comparing results with earlier runs.
CONDITION_A_PROMPT = """Extract a knowledge graph from the document below.

Return ONLY valid JSON with this exact schema:
{
  "entities": [
    {"name": string, "type": string}
  ],
  "edges": [
    {"subject": string, "predicate": string, "object": string}
  ]
}

Entity types: use whatever fits the entity. Do not constrain yourself to a fixed list.
Edge predicates: natural language phrases that capture the actual relationship the document states or implies.

Extract every entity and every relationship the document states or strongly implies.
Both subject and object in every edge must appear in entities.

JSON only, no commentary, no markdown fences.

DOCUMENT:
"""

LOCAL_METADATA_PROMPT = """Analyze the document below and produce metadata describing its surface features. Do NOT extract entities. Do NOT identify content. Only produce structural and surface-level metadata.

Return ONLY valid JSON with this exact schema:
{
  "language": "en or other",
  "char_length": integer,
  "primary_format": "prose, presentation, list, form, code, or mixed",
  "structural_signals": {
    "has_headings": boolean,
    "has_bullet_lists": boolean,
    "has_numbered_lists": boolean,
    "has_tables": boolean,
    "has_code_blocks": boolean,
    "has_dates": boolean
  },
  "content_signals": {
    "has_named_people": boolean,
    "has_institutional_language": boolean,
    "has_technical_terminology": boolean,
    "has_first_person": boolean,
    "has_quotations": boolean
  },
  "domain_class": "technical, administrative, personal, educational, creative, reference, or mixed",
  "one_sentence_summary": "string of 25 words or fewer describing what the document is about"
}

JSON only, no commentary.

DOCUMENT:
"""

# The "{metadata_json}" placeholder is filled with str.replace (not
# str.format) because the schema below contains literal braces.
CONDITION_B_API_PROMPT = """You are extracting a knowledge graph from a document. The document has been pre-analyzed by a local model and the following metadata is provided as orienting context — not as constraint. Extract every entity and every relationship in the document. Do not limit your extraction to what the metadata suggests; the metadata is here to orient your reading, not to bound it.

DOCUMENT METADATA:
{metadata_json}

Return ONLY valid JSON with this exact schema:
{
  "entities": [
    {"name": string, "type": string}
  ],
  "edges": [
    {"subject": string, "predicate": string, "object": string}
  ]
}

Entity types: use whatever fits.
Edge predicates: natural language phrases capturing the actual relationship.
Both subject and object in every edge must appear in entities.

Extract every entity and every relationship the document states or strongly implies. Do not filter for salience.

JSON only, no commentary, no markdown fences.

DOCUMENT:
"""


def strip_json_fences(text):
    """Return *text* without a leading/trailing markdown code fence.

    Empty or None input yields "". Only a fence at the very start
    (optionally tagged ``json``) and at the very end is removed.
    """
    if not text:
        return ""
    t = text.strip()
    t = re.sub(r"^```(?:json)?\s*", "", t)
    t = re.sub(r"\s*```$", "", t)
    return t.strip()


def fetch_document_text(pg_conn, source):
    """Reassemble a document from its rows in the ``embeddings`` table.

    Rows for *source* are joined in ``id`` order with blank lines between
    them. Returns ``(text capped at MAX_DOC_CHARS, uncapped length)``, or
    ``(None, 0)`` when no row matches.
    """
    # The context manager closes the cursor even if the query raises.
    with pg_conn.cursor() as cur:
        cur.execute(
            "SELECT document FROM embeddings WHERE source = %s ORDER BY id",
            (source,),
        )
        rows = cur.fetchall()
    if not rows:
        return None, 0
    # A NULL chunk would make str.join raise; treat it as empty text.
    full = "\n\n".join(r[0] if r[0] is not None else "" for r in rows)
    return full[:MAX_DOC_CHARS], len(full)


def call_haiku(client, prompt_text):
    """Send *prompt_text* as a single user message to HAIKU_MODEL.

    Returns a dict with token usage, wall-clock latency in seconds, the
    text of the first content block ("" if there is none) and the stop
    reason. Exceptions from the client propagate to the caller.
    """
    t0 = time.time()
    resp = client.messages.create(
        model=HAIKU_MODEL,
        max_tokens=HAIKU_MAX_TOKENS,
        temperature=HAIKU_TEMPERATURE,
        messages=[{"role": "user", "content": prompt_text}],
    )
    return {
        "input_tokens": resp.usage.input_tokens,
        "output_tokens": resp.usage.output_tokens,
        "latency_s": round(time.time() - t0, 2),
        "response_text": resp.content[0].text if resp.content else "",
        "stop_reason": resp.stop_reason,
    }


def call_local_metadata(document_text):
    """Ask the local model for surface metadata about *document_text*.

    Returns ``{"response", "latency_s"}`` on success. Any failure is
    reported as ``{"error", "latency_s"}`` instead of raising, so one bad
    document does not abort the run.
    """
    t0 = time.time()
    try:
        resp = requests.post(
            OLLAMA_URL,
            json={
                "model": LOCAL_MODEL,
                "prompt": LOCAL_METADATA_PROMPT + document_text,
                "stream": False,
                "format": "json",
                "options": {
                    "num_predict": 1024,
                    "temperature": 0,
                    "num_ctx": LOCAL_NUM_CTX,
                },
            },
            timeout=LOCAL_TIMEOUT,
        )
        resp.raise_for_status()
        return {
            "response": resp.json().get("response", ""),
            "latency_s": round(time.time() - t0, 2),
        }
    except Exception as e:  # deliberate best-effort: the caller records the error
        return {"error": str(e), "latency_s": round(time.time() - t0, 2)}


def parse_graph_full(raw):
    """Return (entities_list, edges_list, parsed_ok). Lists for metric computation.

    ``parsed_ok`` is True only when *raw* decodes to a JSON object whose
    "entities" and "edges" values are both lists; otherwise the result is
    ``(None, None, False)``.
    """
    cleaned = strip_json_fences(raw)
    if not cleaned:
        return None, None, False
    try:
        data = json.loads(cleaned)
    except json.JSONDecodeError:
        return None, None, False
    if not isinstance(data, dict):
        return None, None, False
    ents = data.get("entities")
    edges = data.get("edges")
    if isinstance(ents, list) and isinstance(edges, list):
        return ents, edges, True
    return None, None, False


def parse_metadata(raw):
    """Decode the local model's metadata reply.

    Returns the decoded JSON object, or None when *raw* is empty, is not
    valid JSON, or decodes to something other than an object.
    """
    cleaned = strip_json_fences(raw)
    if not cleaned:
        return None
    try:
        data = json.loads(cleaned)
    except json.JSONDecodeError:
        return None
    # A bare list/string/number is valid JSON but is not usable metadata.
    return data if isinstance(data, dict) else None


def _norm_label(value):
    """Return the stripped, lower-cased string form of *value* ("" for None)."""
    return "" if value is None else str(value).strip().lower()


def graph_metrics(entities, edges):
    """Compute graph quality metrics. Inputs are lists from parse_graph_full.

    Returns None if either input is None, otherwise a dict with entity and
    edge counts, distinct predicate and entity-type counts (compared
    case-insensitively), average degree, and the size of the largest
    connected component (absolute and as a percentage of n_entities).
    Non-dict items are counted in n_entities / n_edges but ignored
    everywhere else.
    """
    if entities is None or edges is None:
        return None

    n_entities = len(entities)
    n_edges = len(edges)

    # Predicate diversity
    predicates = {
        _norm_label(e.get("predicate"))
        for e in edges
        if isinstance(e, dict) and e.get("predicate")
    }

    # Entity type diversity
    types = {
        _norm_label(ent.get("type"))
        for ent in entities
        if isinstance(ent, dict) and ent.get("type")
    }

    # Average degree (edges*2 / entities — each edge touches two nodes)
    avg_degree = (2 * n_edges / n_entities) if n_entities > 0 else 0.0

    # Adjacency over declared entity names; an edge whose endpoints are not
    # both declared entities is left out of the component computation.
    entity_names = {
        _norm_label(ent.get("name"))
        for ent in entities
        if isinstance(ent, dict) and ent.get("name")
    }
    adj = {name: set() for name in entity_names}
    for e in edges:
        if not isinstance(e, dict):
            continue
        # _norm_label maps a null endpoint to "" rather than the text "none".
        s = _norm_label(e.get("subject"))
        o = _norm_label(e.get("object"))
        if s in adj and o in adj:
            adj[s].add(o)
            adj[o].add(s)

    # Largest connected component, found with an explicit-stack traversal.
    visited = set()
    largest = 0
    for start in adj:
        if start in visited:
            continue
        component = 0
        stack = [start]
        while stack:
            node = stack.pop()
            if node in visited:
                continue
            visited.add(node)
            component += 1
            stack.extend(nb for nb in adj[node] if nb not in visited)
        largest = max(largest, component)

    return {
        "n_entities": n_entities,
        "n_edges": n_edges,
        "predicate_diversity": len(predicates),
        "type_diversity": len(types),
        "avg_degree": round(avg_degree, 2),
        "largest_component": largest,
        "largest_component_pct": round(100 * largest / n_entities, 1) if n_entities else 0.0,
    }


def stratify(docs):
    """Pick up to 5 small, 10 medium and 5 large docs, in input order.

    Buckets use each doc's "content_length": small < 1000,
    medium 1000-4999, large >= 5000.
    """
    small = [d for d in docs if d["content_length"] < 1000]
    medium = [d for d in docs if 1000 <= d["content_length"] < 5000]
    large = [d for d in docs if d["content_length"] >= 5000]
    return small[:5] + medium[:10] + large[:5]


def fmt_metrics(m):
    """Render a graph_metrics dict as a one-line summary ("n/a" for None)."""
    if m is None:
        return "n/a"
    return (f"e={m['n_entities']} edge={m['n_edges']} "
            f"pred={m['predicate_diversity']} type={m['type_diversity']} "
            f"deg={m['avg_degree']} comp={m['largest_component']}/{m['n_entities']}")


def _fmt_pct(value):
    """Format a signed percentage with one decimal, or "n/a" when it is None."""
    return "n/a" if value is None else f"{value:+.1f}%"


def _condition_a_record(a, a_metrics):
    """Build the stored "condition_a" entry from a call result (or error dict)."""
    return {
        "input_tokens": a.get("input_tokens"),
        "output_tokens": a.get("output_tokens"),
        "latency_s": a.get("latency_s"),
        "metrics": a_metrics,
        "stop_reason": a.get("stop_reason"),
        "response_text": a.get("response_text", "")[:4000],
        "error": a.get("error"),
    }


def _run_extraction(client, prompt_text, label):
    """Run one Haiku extraction, print its stats, return (call_dict, metrics).

    On any exception the failure is printed and ``({"error": ...}, None)``
    is returned so the remaining documents are still processed.
    """
    try:
        call = call_haiku(client, prompt_text)
        ents, edges, ok = parse_graph_full(call["response_text"])
        metrics = graph_metrics(ents, edges) if ok else None
        print(f" {label}: in={call['input_tokens']} out={call['output_tokens']} "
              f"stop={call['stop_reason']} t={call['latency_s']}s", flush=True)
        print(f" {fmt_metrics(metrics)}", flush=True)
        return call, metrics
    except Exception as e:
        print(f" {label} FAILED: {e}", flush=True)
        return {"error": str(e)}, None


def _print_deltas(a, b, a_metrics, b_metrics):
    """Print B-vs-A percentage deltas for one doc when both API calls succeeded."""
    if "input_tokens" not in a or "input_tokens" not in b:
        return
    in_pct = (b["input_tokens"] - a["input_tokens"]) / a["input_tokens"] * 100 if a["input_tokens"] else 0.0
    out_pct = (b["output_tokens"] - a["output_tokens"]) / a["output_tokens"] * 100 if a["output_tokens"] else 0.0
    edge_pct_str = "n/a"
    pred_pct_str = "n/a"
    if a_metrics and b_metrics:
        if a_metrics["n_edges"] > 0:
            edge_pct_str = f"{(b_metrics['n_edges'] - a_metrics['n_edges']) / a_metrics['n_edges'] * 100:+.1f}%"
        if a_metrics["predicate_diversity"] > 0:
            pred_pct_str = f"{(b_metrics['predicate_diversity'] - a_metrics['predicate_diversity']) / a_metrics['predicate_diversity'] * 100:+.1f}%"
    print(f" Δ in={in_pct:+.1f}% out={out_pct:+.1f}% edges={edge_pct_str} pred={pred_pct_str}", flush=True)


def _process_document(index, total, doc_meta, client, pg_conn):
    """Run conditions A and B for one sampled doc and return its result record."""
    source = doc_meta["source"]
    doc_text, original_len = fetch_document_text(pg_conn, source)
    if not doc_text:
        print(f"[{index:02d}/{total}] {source[:55]} — SKIP (not in pgvector)")
        return {"source": source, "skipped": "not_in_pgvector"}

    sent_len = len(doc_text)
    truncated = original_len > sent_len
    size_bucket = (
        "small" if sent_len < 1000
        else "medium" if sent_len < 5000
        else "large"
    )
    trunc_marker = "*" if truncated else " "
    print(f"[{index:02d}/{total}] [{size_bucket:6s}] [{sent_len:>5}c{trunc_marker}] {source[:55]}", flush=True)

    # Condition A
    a, a_metrics = _run_extraction(client, CONDITION_A_PROMPT + doc_text, "A")
    record = {
        "source": source,
        "size_bucket": size_bucket,
        "doc_chars_original": original_len,
        "doc_chars_sent": sent_len,
        "truncated": truncated,
        "condition_a": _condition_a_record(a, a_metrics),
    }

    # Condition B local metadata pass
    local_result = call_local_metadata(doc_text)
    if "error" in local_result:
        print(f" B local FAILED: {local_result['error']}", flush=True)
        record["condition_b"] = {
            "skipped": "local_model_failed",
            "local_error": local_result["error"],
            "local_latency_s": local_result.get("latency_s"),
        }
        return record

    local_raw = local_result["response"]
    metadata = parse_metadata(local_raw)
    print(f" B local: t={local_result['latency_s']}s metadata_parsed={metadata is not None}", flush=True)
    if metadata is None:
        print(" B: metadata parse failed — skipping API call", flush=True)
        record["condition_b"] = {
            "skipped": "metadata_parse_failed",
            "local_latency_s": local_result.get("latency_s"),
            "local_raw": local_raw[:1000],
        }
        return record

    metadata_json = json.dumps(metadata, ensure_ascii=False, indent=2)
    b_prompt = CONDITION_B_API_PROMPT.replace("{metadata_json}", metadata_json) + doc_text
    b, b_metrics = _run_extraction(client, b_prompt, "B api")

    # Per-doc deltas
    _print_deltas(a, b, a_metrics, b_metrics)

    record["condition_b"] = {
        "local_latency_s": local_result.get("latency_s"),
        "local_metadata": metadata,
        "local_raw": local_raw[:1000],
        "api_input_tokens": b.get("input_tokens"),
        "api_output_tokens": b.get("output_tokens"),
        "api_latency_s": b.get("latency_s"),
        "metrics": b_metrics,
        "stop_reason": b.get("stop_reason"),
        "response_text": b.get("response_text", "")[:4000],
        "error": b.get("error"),
    }
    return record


def _token_totals(rows):
    """Return (a_in, a_out, b_in, b_out) token sums over *rows*."""
    a_in = sum(r["condition_a"]["input_tokens"] for r in rows)
    a_out = sum(r["condition_a"]["output_tokens"] for r in rows)
    b_in = sum(r["condition_b"]["api_input_tokens"] for r in rows)
    b_out = sum(r["condition_b"]["api_output_tokens"] for r in rows)
    return a_in, a_out, b_in, b_out


def _avg_metric(rows, condition, key):
    """Mean of metric *key* for *condition* over *rows*, rounded to 2 places."""
    vals = [r[condition]["metrics"][key] for r in rows if r[condition]["metrics"]]
    return round(statistics.mean(vals), 2) if vals else None


# (suffix used in the summary keys, key inside a graph_metrics dict)
_BUCKET_METRICS = (
    ("entities", "n_entities"),
    ("edges", "n_edges"),
    ("predicate_diversity", "predicate_diversity"),
    ("type_diversity", "type_diversity"),
    ("degree", "avg_degree"),
    ("largest_component_pct", "largest_component_pct"),
)


def _bucket_summary(rows):
    """Aggregate token deltas and average metrics for one size bucket."""
    if not rows:
        return None
    ai, ao, bi, bo = _token_totals(rows)
    stats = {
        "n": len(rows),
        "input_delta_pct": round((bi - ai) / ai * 100, 2) if ai else None,
        "output_delta_pct": round((bo - ao) / ao * 100, 2) if ao else None,
    }
    for suffix, key in _BUCKET_METRICS:
        stats[f"a_avg_{suffix}"] = _avg_metric(rows, "condition_a", key)
        stats[f"b_avg_{suffix}"] = _avg_metric(rows, "condition_b", key)
    return stats


def main():
    """Run the A/B experiment over the stratified sample and write the results.

    Exits with status 1 when ANTHROPIC_API_KEY or PG_DSN is unset or when
    V2_FILE is missing. Writes the summary JSON to OUTPUT_FILE and prints a
    per-bucket report.
    """
    api_key = os.environ.get("ANTHROPIC_API_KEY")
    pg_dsn = os.environ.get("PG_DSN")
    if not api_key or not pg_dsn:
        print("ERROR: ANTHROPIC_API_KEY or PG_DSN not set", file=sys.stderr)
        sys.exit(1)
    if not V2_FILE.exists():
        print(f"ERROR: {V2_FILE} not found", file=sys.stderr)
        sys.exit(1)

    with open(V2_FILE, encoding="utf-8") as f:
        v2 = json.load(f)
    docs_meta = [d for d in v2["documents"] if d.get("status") == "SUCCESS"]
    sample = stratify(docs_meta)

    print(f"Sample: {len(sample)} docs (5s/10m/5l, file order)")
    print(f"Mistral context: {LOCAL_NUM_CTX} tokens, doc cap {MAX_DOC_CHARS} chars")
    print(f"Haiku model: {HAIKU_MODEL} temp={HAIKU_TEMPERATURE}")
    print("Test: base-class metadata as orienting frame, NOT entity drafting")
    print()

    client = anthropic.Anthropic(api_key=api_key)
    pg_conn = psycopg2.connect(pg_dsn)
    results = []
    started_at = datetime.now(timezone.utc).isoformat()
    t_total = time.time()

    # Close the connection even if a document fails in an unexpected way.
    try:
        for i, doc_meta in enumerate(sample, 1):
            results.append(_process_document(i, len(sample), doc_meta, client, pg_conn))
    finally:
        pg_conn.close()

    total_elapsed = round(time.time() - t_total, 1)

    # Only docs where both conditions produced a parsed graph are compared.
    valid = [
        r for r in results
        if r.get("condition_a", {}).get("metrics") is not None
        and r.get("condition_b", {}).get("metrics") is not None
    ]

    a_in, a_out, b_in, b_out = _token_totals(valid)
    a_cost = (a_in * HAIKU_IN_PER_M + a_out * HAIKU_OUT_PER_M) / 1_000_000
    b_cost = (b_in * HAIKU_IN_PER_M + b_out * HAIKU_OUT_PER_M) / 1_000_000

    by_bucket = {
        bucket: _bucket_summary([r for r in valid if r["size_bucket"] == bucket])
        for bucket in ("small", "medium", "large")
    }

    summary = {
        "experiment": "base_class_test",
        "title": "Base-Class Enrichment — OOP Framing",
        "started_at": started_at,
        "completed_at": datetime.now(timezone.utc).isoformat(),
        "haiku_model": HAIKU_MODEL,
        "local_model": LOCAL_MODEL,
        "max_doc_chars": MAX_DOC_CHARS,
        "n_documents": len(sample),
        "n_valid_pairs": len(valid),
        "total_elapsed_s": total_elapsed,
        "totals": {
            "a_input_tokens": a_in,
            "a_output_tokens": a_out,
            "b_input_tokens": b_in,
            "b_output_tokens": b_out,
            "a_cost_usd": round(a_cost, 4),
            "b_cost_usd": round(b_cost, 4),
            "cost_delta_usd": round(b_cost - a_cost, 4),
            "cost_delta_pct": round((b_cost - a_cost) / a_cost * 100, 2) if a_cost else None,
            "note": "API cost only — local Mistral runtime on VPS not monetized",
        },
        "by_size_bucket": by_bucket,
        "results": results,
    }

    OUTPUT_FILE.parent.mkdir(parents=True, exist_ok=True)
    with open(OUTPUT_FILE, "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2)

    print()
    print("=" * 60)
    print(f"DONE — {len(valid)}/{len(sample)} valid pairs in {total_elapsed}s")
    print(f"A total cost: ${a_cost:.4f} (in={a_in} out={a_out})")
    print(f"B total cost: ${b_cost:.4f} (in={b_in} out={b_out})")
    delta_pct = summary["totals"]["cost_delta_pct"]
    if delta_pct is not None:
        verdict = "B cheaper" if delta_pct < 0 else "B more expensive"
        print(f"Cost delta: {delta_pct:+.2f}% ({verdict})")
    print()
    print("By bucket — graph metrics (A vs B):")
    for bucket, stats in by_bucket.items():
        if stats:
            print(f" {bucket:6s} (n={stats['n']}):")
            # _fmt_pct prints "n/a" instead of raising when a delta is None.
            print(f" cost: in {_fmt_pct(stats['input_delta_pct'])} out {_fmt_pct(stats['output_delta_pct'])}")
            print(f" entities: A={stats['a_avg_entities']} B={stats['b_avg_entities']}")
            print(f" edges: A={stats['a_avg_edges']} B={stats['b_avg_edges']}")
            print(f" predicate diversity: A={stats['a_avg_predicate_diversity']} B={stats['b_avg_predicate_diversity']}")
            print(f" type diversity: A={stats['a_avg_type_diversity']} B={stats['b_avg_type_diversity']}")
            print(f" avg degree: A={stats['a_avg_degree']} B={stats['b_avg_degree']}")
            print(f" largest component %: A={stats['a_avg_largest_component_pct']} B={stats['b_avg_largest_component_pct']}")
    print()
    print(f"Results: {OUTPUT_FILE}")


if __name__ == "__main__":
    main()