209 lines
7.2 KiB
Python
209 lines
7.2 KiB
Python
#!/usr/bin/env python3
|
|
"""E1.4 orchestration — cascade re-extraction at n=30, group_id=aaron_cascade_e14."""
|
|
import json
|
|
import os
|
|
import requests
|
|
import time
|
|
from pathlib import Path
|
|
import psycopg2
|
|
from dotenv import load_dotenv
|
|
|
|
# Load the project's .env file before any environment variable is read below
# (presumably where PG_DSN is defined — verify against the deployment).
load_dotenv(Path.home().joinpath("aaronai", ".env"))

# Experiment artifacts: the sampled episode list read by main() and the
# per-episode results file, which load_state() also reads to resume a run.
EXPERIMENTS = Path.home().joinpath("aaronai", "experiments")
SAMPLE_FILE = EXPERIMENTS.joinpath("e14_sample.json")
RESULTS_FILE = EXPERIMENTS.joinpath("e14_cascade_results.json")

# Connection settings and run parameters.
PG_DSN = os.environ["PG_DSN"]  # KeyError at import time when the variable is unset
SIDECAR_URL = "http://localhost:8001"
TEST_GROUP_ID = "aaron_cascade_e14"
MAX_DOC_CHARS = 12000  # character cap applied to document text before it is sent out
|
|
|
|
METADATA_PROMPT = """You are a metadata extraction system. Given a document, produce structural and content metadata in strict JSON format.
|
|
|
|
Do not summarize the content beyond the one-sentence summary field. Do not extract entities or relationships. Do not interpret meaning. Produce only the metadata schema below.
|
|
|
|
Output JSON only. No prose, no explanation, no markdown code fences.
|
|
|
|
Schema:
|
|
{
|
|
"language": "<ISO 639-1 code>",
|
|
"char_length": <integer>,
|
|
"primary_format": "<prose|slides|code|structured|mixed>",
|
|
"structural_signals": {
|
|
"has_headings": <boolean>,
|
|
"has_bullet_lists": <boolean>,
|
|
"has_numbered_lists": <boolean>,
|
|
"has_tables": <boolean>,
|
|
"has_code_blocks": <boolean>,
|
|
"has_dates": <boolean>
|
|
},
|
|
"content_signals": {
|
|
"has_named_people": <boolean>,
|
|
"has_institutional_language": <boolean>,
|
|
"has_technical_terminology": <boolean>,
|
|
"has_first_person": <boolean>,
|
|
"has_quotations": <boolean>
|
|
},
|
|
"domain_class": "<technical|administrative|educational|personal|conversational>",
|
|
"one_sentence_summary": "<one sentence describing what the document is about>"
|
|
}
|
|
|
|
Document:
|
|
"""
|
|
|
|
|
|
def get_pg():
    """Open and return a new psycopg2 connection for the DSN held in ``PG_DSN``.

    Each call creates a fresh connection; closing it is the caller's job.
    """
    connection = psycopg2.connect(PG_DSN)
    return connection
|
|
|
|
|
|
def fetch_source_text(source):
    """Return the full stored text for ``source``, or ``None`` if there is none.

    Every ``document`` value in the ``embeddings`` table whose ``source``
    column equals ``source`` is concatenated in ``id`` order, with a blank
    line between consecutive values.

    Args:
        source: Value matched against ``embeddings.source``.

    Returns:
        The concatenated text, or ``None`` when the aggregate is NULL (no
        matching rows, or no non-NULL ``document`` among them).
    """
    conn = get_pg()
    try:
        # The cursor context manager closes the cursor on exit; the connection
        # is closed in ``finally`` so a failing query no longer leaks it.
        with conn.cursor() as cur:
            cur.execute(
                """
                SELECT STRING_AGG(document, E'\n\n' ORDER BY id) AS full_doc
                FROM embeddings WHERE source = %s
                """,
                (source,),
            )
            row = cur.fetchone()
    finally:
        conn.close()
    return None if row is None else row[0]
|
|
|
|
|
|
def run_mistral_metadata(text, max_retries=2):
    """Request structural/content metadata for a document from the local model.

    The first ``MAX_DOC_CHARS`` characters of ``text`` are appended to
    ``METADATA_PROMPT`` and posted to ``http://localhost:11434/api/generate``
    (model ``mistral:latest``, non-streaming, JSON output requested).

    Args:
        text: Full document text; only the leading ``MAX_DOC_CHARS`` characters
            are sent.
        max_retries: Total number of attempts made when the request times out
            or the connection fails.

    Returns:
        On success, the parsed metadata dict with ``"char_length"`` overwritten
        by the length of the truncated text actually sent. On failure, a dict
        with an ``"error"`` key: the model output was not valid JSON or was not
        a JSON object (``"raw"`` holds its first 500 characters), or every
        attempt ended in a read timeout / connection error.

    Raises:
        requests.exceptions.HTTPError: On a non-2xx reply (not retried).
        KeyError: If the reply body has no ``"response"`` field.
    """
    truncated = text[:MAX_DOC_CHARS]
    prompt = METADATA_PROMPT + truncated
    last_err = None
    for attempt in range(max_retries):
        # Only the network round trip is retried; parsing happens afterwards so
        # a malformed answer is reported immediately instead of re-requested.
        try:
            response = requests.post(
                "http://localhost:11434/api/generate",
                json={"model": "mistral:latest", "prompt": prompt, "stream": False, "format": "json"},
                timeout=300,
            )
            response.raise_for_status()
            raw = response.json()["response"]
        except (requests.exceptions.ReadTimeout, requests.exceptions.ConnectionError) as e:
            last_err = e
            if attempt < max_retries - 1:
                print(f" (retry {attempt+1} after {type(e).__name__})", end=" ", flush=True)
                time.sleep(5)
            continue

        try:
            metadata = json.loads(raw)
        except json.JSONDecodeError:
            return {"error": "JSON parse failed", "raw": raw[:500]}
        # Valid JSON is not necessarily an object: a list, string or number
        # cannot carry the schema and would break the key assignment below.
        if not isinstance(metadata, dict):
            return {"error": "JSON response is not an object", "raw": raw[:500]}
        # Trust the measured length over whatever the model reported.
        metadata["char_length"] = len(truncated)
        return metadata
    return {"error": f"After {max_retries} retries: {last_err}"}
|
|
|
|
|
|
def format_metadata_as_orientation(metadata):
    """Turn a metadata dict into a short orientation paragraph for extraction.

    Args:
        metadata: Dict as returned by the metadata pass. The keys
            ``domain_class``, ``primary_format`` and ``one_sentence_summary``
            are read; missing ones fall back to ``"unknown"`` (domain/format)
            or an empty summary.

    Returns:
        A single-line instruction string, or ``None`` when ``metadata`` carries
        an ``"error"`` key.
    """
    if "error" in metadata:
        return None

    domain_label = metadata.get("domain_class", "unknown")
    format_label = metadata.get("primary_format", "unknown")
    summary_text = metadata.get("one_sentence_summary", "")

    # Fragments are joined with single spaces, so an empty summary leaves a
    # double space after "Summary:".
    fragments = [
        f"This is a {domain_label} document in {format_label} format.",
        f"Summary: {summary_text}",
        "This metadata is provided to orient your extraction, not to constrain it.",
        "Extract entities and relationships freely from the document text itself;",
        "the metadata is descriptive context, not a checklist.",
    ]
    return " ".join(fragments)
|
|
|
|
|
|
def submit_episode_singular(name, content, custom_instructions):
    """POST a single episode to the sidecar's ``/episodes`` endpoint.

    Args:
        name: Episode name.
        content: Episode text; only the first ``MAX_DOC_CHARS`` characters are
            sent.
        custom_instructions: Value sent as ``custom_extraction_instructions``
            (may be ``None``).

    Returns:
        The decoded JSON body of the sidecar's reply.

    Raises:
        requests.exceptions.HTTPError: On a non-2xx reply.
    """
    endpoint = f"{SIDECAR_URL}/episodes"
    body = dict(
        name=name,
        content=content[:MAX_DOC_CHARS],
        source_description="e14_replication_run",
        timestamp="2026-04-29T00:00:00",
        group_id=TEST_GROUP_ID,
        custom_extraction_instructions=custom_instructions,
    )
    reply = requests.post(endpoint, json=body, timeout=300)
    reply.raise_for_status()
    return reply.json()
|
|
|
|
|
|
def load_state():
    """Load results from a previous run so the current run can resume.

    Returns:
        A ``(results, completed)`` pair: the list stored under ``"results"`` in
        ``RESULTS_FILE`` and the set of record names that have a
        ``"submit_result"`` key. ``([], set())`` when the file does not exist.
    """
    if not RESULTS_FILE.exists():
        return [], set()

    with open(RESULTS_FILE) as fh:
        saved = json.load(fh)

    prior_records = saved.get("results", [])
    # Only records that reached a successful submission count as done.
    finished_names = {rec["name"] for rec in prior_records if "submit_result" in rec}
    return prior_records, finished_names
|
|
|
|
|
|
def _save_results(results):
    """Write ``results`` to ``RESULTS_FILE`` via a temporary sibling file."""
    tmp_path = RESULTS_FILE.with_name(RESULTS_FILE.name + ".tmp")
    with open(tmp_path, "w") as f:
        json.dump({"results": results}, f, indent=2, default=str)
    # Renaming over the target keeps the previous file intact until the new one
    # is fully written, so an interrupted run does not leave unparsable state.
    tmp_path.replace(RESULTS_FILE)


def _record_result(results, record):
    """Make ``record`` the only entry for its name in ``results`` and save."""
    # A resumed run retries episodes whose earlier attempt failed; drop that
    # stale entry so each episode appears once in the results file.
    results[:] = [r for r in results if r.get("name") != record["name"]]
    results.append(record)
    _save_results(results)


def main():
    """Run the E1.4 cascade over every sampled episode, resuming if possible.

    For each entry under ``"selected"`` in ``SAMPLE_FILE`` that has not already
    been submitted successfully, this fetches the source text, generates
    metadata, turns it into extraction instructions, and submits the episode
    to the sidecar. A record per episode is written to ``RESULTS_FILE`` after
    every episode, so the run can be interrupted and restarted.

    A metadata failure does not stop the episode: it is still submitted, with
    ``None`` as its custom instructions. A submission failure is recorded under
    ``"submit_error"`` and the episode is retried on the next run.
    """
    with open(SAMPLE_FILE) as f:
        sample = json.load(f)
    selected = sample["selected"]

    results, completed = load_state()
    if completed:
        # Count against the current sample: the results file may hold completed
        # names that are not part of it.
        remaining = sum(ep["name"] not in completed for ep in selected)
        print(f"Resuming — {len(completed)} sources already completed, {remaining} remaining\n")
    else:
        print(f"E1.4 cascade replication — {len(selected)} episodes to group_id={TEST_GROUP_ID}\n")

    for i, ep in enumerate(selected, 1):
        name = ep["name"]
        bucket = ep["bucket"]
        if name in completed:
            print(f"[{i}/{len(selected)}] [{bucket}] {name} — SKIP (already completed)")
            continue

        print(f"[{i}/{len(selected)}] [{bucket}] {name}")
        record = {"name": name, "bucket": bucket, "tier1_entities": ep["entities"]}
        if ep.get("subtype"):
            record["subtype"] = ep["subtype"]

        print(" Fetching source text...", end=" ", flush=True)
        text = fetch_source_text(name)
        if text is None:
            print("FAILED — no chunks in pgvector")
            record["error"] = "no source text"
            _record_result(results, record)
            continue
        record["doc_chars"] = len(text)
        print(f"{len(text)} chars")

        print(" Generating Mistral metadata...", end=" ", flush=True)
        # Monotonic clock: elapsed times are unaffected by wall-clock changes.
        t0 = time.monotonic()
        metadata = run_mistral_metadata(text)
        elapsed = time.monotonic() - t0
        record["metadata"] = metadata
        record["metadata_elapsed_s"] = round(elapsed, 1)
        if "error" in metadata:
            print(f"FAILED in {elapsed:.1f}s")
        else:
            print(f"{elapsed:.1f}s — domain={metadata.get('domain_class')}, format={metadata.get('primary_format')}")

        # None when the metadata pass failed; the episode is submitted anyway.
        custom_instructions = format_metadata_as_orientation(metadata)
        record["custom_extraction_instructions"] = custom_instructions
        print(" Submitting via /episodes...", end=" ", flush=True)
        t0 = time.monotonic()
        try:
            result = submit_episode_singular(name, text, custom_instructions)
            elapsed = time.monotonic() - t0
            print(f"{elapsed:.1f}s — OK")
            record["submit_elapsed_s"] = round(elapsed, 1)
            record["submit_result"] = result
        except Exception as e:
            # Loop boundary: one failed submission must not abort the whole
            # run; the error is recorded and the episode stays retryable.
            elapsed = time.monotonic() - t0
            print(f"{elapsed:.1f}s — FAILED: {e}")
            record["submit_error"] = str(e)

        _record_result(results, record)
        print()

    print(f"\nDone. Results saved to {RESULTS_FILE}")
|
|
|
|
|
|
# Run the orchestration only when this file is executed as a script, so that
# importing the module does not start a run.
if __name__ == "__main__":
    main()
|