feat: stage2/3 pipeline, taxonomy-free cascade, E1.8/E4 experiments, corpus migration state
This commit is contained in:
@@ -0,0 +1,298 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Tier 1 Graphiti Migration — pgvector to Graphiti for ~300 most-recent sources.
|
||||
Resumable via state file at ~/aaronai/experiments/tier1_migration_state.json.
|
||||
|
||||
Usage:
|
||||
python3 ~/aaronai/scripts/tier1_migration.py --dry-run
|
||||
python3 ~/aaronai/scripts/tier1_migration.py
|
||||
python3 ~/aaronai/scripts/tier1_migration.py --reset
|
||||
"""
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
import psycopg2
|
||||
import requests
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv(Path.home() / "aaronai" / ".env", override=True)
|
||||
|
||||
GRAPHITI_URL = "http://localhost:8001"
|
||||
PG_DSN = os.environ["PG_DSN"]
|
||||
|
||||
MAX_SOURCES = 300
|
||||
BATCH_SIZE = 4
|
||||
BATCH_DELAY_S = 5
|
||||
LONG_DOC_THRESHOLD = 5000
|
||||
LONG_DOC_BATCH_SIZE = 2
|
||||
|
||||
EXPERIMENTS = Path.home() / "aaronai" / "experiments"
|
||||
STATE_FILE = EXPERIMENTS / "tier1_migration_state.json"
|
||||
RESULTS_FILE = EXPERIMENTS / "tier1_migration_results.json"
|
||||
EXPERIMENTS.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def load_state():
|
||||
if STATE_FILE.exists():
|
||||
data = json.loads(STATE_FILE.read_text())
|
||||
return set(data.get("ingested", [])), data.get("started_at"), data.get("total_cost_estimate", 0)
|
||||
return set(), None, 0
|
||||
|
||||
|
||||
def save_state(ingested, started_at, total_cost_estimate):
|
||||
STATE_FILE.write_text(json.dumps({
|
||||
"ingested": sorted(ingested),
|
||||
"started_at": started_at,
|
||||
"last_updated": time.strftime("%Y-%m-%dT%H:%M:%S"),
|
||||
"total_cost_estimate": round(total_cost_estimate, 4),
|
||||
"count": len(ingested),
|
||||
}, indent=2))
|
||||
|
||||
|
||||
def fetch_tier1_sources(cur, max_sources, exclude_set):
|
||||
cur.execute("""
|
||||
SELECT column_name FROM information_schema.columns
|
||||
WHERE table_name = 'embeddings'
|
||||
""")
|
||||
columns = {r[0] for r in cur.fetchall()}
|
||||
has_created_at = "created_at" in columns
|
||||
|
||||
if has_created_at:
|
||||
order_clause = "MAX(created_at) DESC NULLS LAST"
|
||||
else:
|
||||
order_clause = "MAX(id) DESC"
|
||||
|
||||
cur.execute(f"""
|
||||
SELECT source,
|
||||
STRING_AGG(document, E'\n\n' ORDER BY id) AS full_doc,
|
||||
{("MAX(created_at)" if has_created_at else "NULL")} AS most_recent
|
||||
FROM embeddings
|
||||
GROUP BY source
|
||||
ORDER BY {order_clause}
|
||||
LIMIT %s
|
||||
""", (max_sources * 2,))
|
||||
|
||||
candidates = cur.fetchall()
|
||||
selected = []
|
||||
for source, doc, recent in candidates:
|
||||
if not doc:
|
||||
continue
|
||||
if source in exclude_set:
|
||||
continue
|
||||
selected.append((source, doc, recent))
|
||||
if len(selected) >= max_sources:
|
||||
break
|
||||
return selected
|
||||
|
||||
|
||||
def submit_batch(batch):
|
||||
payload = {
|
||||
"episodes": [
|
||||
{
|
||||
"name": source,
|
||||
"content": doc[:12000],
|
||||
"source_description": "tier1_migration",
|
||||
"timestamp": "2026-04-28T00:00:00",
|
||||
}
|
||||
for source, doc in batch
|
||||
]
|
||||
}
|
||||
t0 = time.time()
|
||||
try:
|
||||
r = requests.post(f"{GRAPHITI_URL}/episodes/bulk", json=payload, timeout=900)
|
||||
elapsed = time.time() - t0
|
||||
return {
|
||||
"ok": r.ok,
|
||||
"status_code": r.status_code,
|
||||
"elapsed_s": round(elapsed, 2),
|
||||
"error": None if r.ok else r.text[:500],
|
||||
"sources": [s for s, _ in batch],
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"ok": False,
|
||||
"status_code": None,
|
||||
"elapsed_s": round(time.time() - t0, 2),
|
||||
"error": str(e)[:500],
|
||||
"sources": [s for s, _ in batch],
|
||||
}
|
||||
|
||||
|
||||
def chunk_for_batches(sources, base_batch_size, long_threshold, long_batch_size):
|
||||
batch = []
|
||||
current_size_target = base_batch_size
|
||||
|
||||
for source, doc, _ in sources:
|
||||
is_long = len(doc) >= long_threshold
|
||||
target_size = long_batch_size if is_long else base_batch_size
|
||||
|
||||
if batch and target_size != current_size_target:
|
||||
yield batch
|
||||
batch = []
|
||||
current_size_target = target_size
|
||||
|
||||
batch.append((source, doc))
|
||||
if len(batch) >= current_size_target:
|
||||
yield batch
|
||||
batch = []
|
||||
|
||||
if batch:
|
||||
yield batch
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--dry-run", action="store_true")
|
||||
parser.add_argument("--max-sources", type=int, default=MAX_SOURCES)
|
||||
parser.add_argument("--reset", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.reset and STATE_FILE.exists():
|
||||
confirm = input(f"Delete state file {STATE_FILE}? [y/N] ")
|
||||
if confirm.lower() == "y":
|
||||
STATE_FILE.unlink()
|
||||
print("State file deleted. Resuming from scratch.")
|
||||
else:
|
||||
print("Aborted.")
|
||||
return
|
||||
|
||||
print("=" * 70)
|
||||
print("Tier 1 Graphiti Migration")
|
||||
print("=" * 70)
|
||||
|
||||
try:
|
||||
r = requests.get(f"{GRAPHITI_URL}/health", timeout=10)
|
||||
if not r.ok:
|
||||
print(f"ERROR: sidecar /health returned {r.status_code}")
|
||||
return
|
||||
print(f"Sidecar: {r.json()}")
|
||||
except Exception as e:
|
||||
print(f"ERROR: sidecar unreachable: {e}")
|
||||
return
|
||||
|
||||
ingested, started_at, prior_cost = load_state()
|
||||
if started_at:
|
||||
print(f"Resuming run started at {started_at}")
|
||||
print(f" {len(ingested)} sources already ingested")
|
||||
print(f" Estimated cost so far: ${prior_cost:.2f}")
|
||||
else:
|
||||
started_at = time.strftime("%Y-%m-%dT%H:%M:%S")
|
||||
print(f"Fresh run starting at {started_at}")
|
||||
|
||||
print()
|
||||
print(f"Fetching tier 1 sources from pgvector (max={args.max_sources})...")
|
||||
conn = psycopg2.connect(PG_DSN)
|
||||
cur = conn.cursor()
|
||||
sources = fetch_tier1_sources(cur, args.max_sources, ingested)
|
||||
cur.close()
|
||||
conn.close()
|
||||
|
||||
print(f" {len(sources)} sources to ingest (excluding {len(ingested)} already done)")
|
||||
|
||||
if not sources:
|
||||
print()
|
||||
print("Nothing to do — all tier 1 sources already ingested.")
|
||||
return
|
||||
|
||||
short = sum(1 for _, d, _ in sources if len(d) < 1000)
|
||||
medium = sum(1 for _, d, _ in sources if 1000 <= len(d) < 5000)
|
||||
long_ = sum(1 for _, d, _ in sources if len(d) >= 5000)
|
||||
print(f" Distribution: short={short} medium={medium} long={long_}")
|
||||
print()
|
||||
|
||||
batches = list(chunk_for_batches(sources, BATCH_SIZE, LONG_DOC_THRESHOLD, LONG_DOC_BATCH_SIZE))
|
||||
print(f" Will submit {len(batches)} batches (delay {BATCH_DELAY_S}s between)")
|
||||
print()
|
||||
|
||||
if args.dry_run:
|
||||
print("DRY RUN - first 10 batches:")
|
||||
for i, batch in enumerate(batches[:10], 1):
|
||||
print(f" [{i}] n={len(batch)} sources={[s[:40] for s, _ in batch]}")
|
||||
print(f" ... and {max(0, len(batches) - 10)} more batches")
|
||||
print()
|
||||
print("Estimated cost: ${:.2f}".format(0.103 * sum(len(b) for b in batches)))
|
||||
print("Estimated runtime: {:.1f} hours".format(
|
||||
(sum(len(b) for b in batches) * 8 + len(batches) * BATCH_DELAY_S) / 3600
|
||||
))
|
||||
return
|
||||
|
||||
total_start = time.time()
|
||||
batch_results = []
|
||||
successful_episodes = 0
|
||||
failed_episodes = 0
|
||||
estimated_cost = prior_cost
|
||||
|
||||
for i, batch in enumerate(batches, 1):
|
||||
avg_chars = int(sum(len(d) for _, d in batch) / len(batch))
|
||||
bucket = "long" if avg_chars >= LONG_DOC_THRESHOLD else ("medium" if avg_chars >= 1000 else "short")
|
||||
print(f"[{i:3d}/{len(batches)}] [{bucket:6s}] n={len(batch)} avg={avg_chars:6d}c", end=" ", flush=True)
|
||||
|
||||
result = submit_batch(batch)
|
||||
batch_results.append(result)
|
||||
|
||||
if result["ok"]:
|
||||
print(f" 200 {result['elapsed_s']}s")
|
||||
for source, _ in batch:
|
||||
ingested.add(source)
|
||||
successful_episodes += len(batch)
|
||||
estimated_cost += 0.103 * len(batch)
|
||||
save_state(ingested, started_at, estimated_cost)
|
||||
else:
|
||||
err = (result["error"] or "")[:80]
|
||||
print(f" FAIL: {err}")
|
||||
failed_episodes += len(batch)
|
||||
save_state(ingested, started_at, estimated_cost)
|
||||
|
||||
if "Max pending queries" in (result["error"] or ""):
|
||||
print(f" FalkorDB queue overflow - pausing 30s")
|
||||
time.sleep(30)
|
||||
elif "timed out" in (result["error"] or "").lower():
|
||||
print(f" Query timeout - pausing 15s")
|
||||
time.sleep(15)
|
||||
elif "rate" in (result["error"] or "").lower() or "429" in (result["error"] or ""):
|
||||
print(f" Rate limited - pausing 60s")
|
||||
time.sleep(60)
|
||||
|
||||
if i < len(batches):
|
||||
time.sleep(BATCH_DELAY_S)
|
||||
|
||||
total_elapsed = time.time() - total_start
|
||||
|
||||
summary = {
|
||||
"started_at": started_at,
|
||||
"completed_at": time.strftime("%Y-%m-%dT%H:%M:%S"),
|
||||
"total_elapsed_s": round(total_elapsed, 1),
|
||||
"total_elapsed_hours": round(total_elapsed / 3600, 2),
|
||||
"n_batches": len(batches),
|
||||
"successful_episodes": successful_episodes,
|
||||
"failed_episodes": failed_episodes,
|
||||
"total_ingested_now": len(ingested),
|
||||
"estimated_total_cost": round(estimated_cost, 2),
|
||||
"batch_results": batch_results,
|
||||
}
|
||||
RESULTS_FILE.write_text(json.dumps(summary, indent=2))
|
||||
|
||||
print()
|
||||
print("=" * 70)
|
||||
print("TIER 1 MIGRATION COMPLETE")
|
||||
print("=" * 70)
|
||||
print(f"Successful episodes: {successful_episodes}/{successful_episodes + failed_episodes}")
|
||||
print(f"Failed episodes: {failed_episodes}")
|
||||
print(f"Total ingested now: {len(ingested)}")
|
||||
print(f"Wall-clock: {total_elapsed/3600:.2f} hours")
|
||||
print(f"Estimated cost: ${estimated_cost:.2f}")
|
||||
print()
|
||||
print(f"State file: {STATE_FILE}")
|
||||
print(f"Results file: {RESULTS_FILE}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
main()
|
||||
except KeyboardInterrupt:
|
||||
print()
|
||||
print("Interrupted. State saved. Re-run to resume.")
|
||||
sys.exit(130)
|
||||
Reference in New Issue
Block a user