299 lines
9.7 KiB
Python
299 lines
9.7 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Tier 1 Graphiti Migration — migrate the ~300 most-recent sources from pgvector to Graphiti.

Resumable via the state file at ~/aaronai/experiments/tier1_migration_state.json.

Usage:
    python3 ~/aaronai/scripts/tier1_migration.py --dry-run   # preview batches and estimates; submits nothing
    python3 ~/aaronai/scripts/tier1_migration.py             # run, or resume from the state file
    python3 ~/aaronai/scripts/tier1_migration.py --reset     # delete the state file (asks first), then run
|
|
"""
|
|
import argparse
|
|
import json
|
|
import os
|
|
import sys
|
|
import time
|
|
from pathlib import Path
|
|
import psycopg2
|
|
import requests
|
|
from dotenv import load_dotenv
|
|
|
|
# Load settings (PG_DSN is read below) from the project's .env file.
load_dotenv(Path.home().joinpath("aaronai", ".env"), override=True)

# --- Service endpoints -------------------------------------------------------
GRAPHITI_URL = "http://localhost:8001"
# Required: a missing PG_DSN raises KeyError here, before any work starts.
PG_DSN = os.environ["PG_DSN"]

# --- Batching knobs ----------------------------------------------------------
MAX_SOURCES = 300          # default for --max-sources
BATCH_SIZE = 4             # episodes per bulk request for ordinary documents
BATCH_DELAY_S = 5          # seconds slept between consecutive batches
LONG_DOC_THRESHOLD = 5000  # len(doc) at or above this counts as a long document
LONG_DOC_BATCH_SIZE = 2    # episodes per bulk request for long documents

# --- Output locations --------------------------------------------------------
EXPERIMENTS = Path.home().joinpath("aaronai", "experiments")
STATE_FILE = EXPERIMENTS.joinpath("tier1_migration_state.json")
RESULTS_FILE = EXPERIMENTS.joinpath("tier1_migration_results.json")
# Created at import time so later state/results writes find the directory.
EXPERIMENTS.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
|
def load_state():
    """Read the resume state written by ``save_state``.

    Returns:
        A ``(ingested, started_at, total_cost_estimate)`` tuple: the set of
        source names already ingested, the stored start timestamp (or
        ``None``), and the stored running cost estimate (``0`` if absent).
        When the state file does not exist, ``(set(), None, 0)`` is returned.

    Raises:
        json.JSONDecodeError: if the state file exists but is not valid JSON.
    """
    if not STATE_FILE.exists():
        return set(), None, 0

    state = json.loads(STATE_FILE.read_text())
    already_done = set(state.get("ingested", []))
    return already_done, state.get("started_at"), state.get("total_cost_estimate", 0)
|
|
|
|
|
|
def save_state(ingested, started_at, total_cost_estimate):
    """Persist resume state to ``STATE_FILE`` atomically.

    Args:
        ingested: Iterable of source names ingested so far (stored sorted).
        started_at: Timestamp string of when the run first started.
        total_cost_estimate: Running cost estimate; stored rounded to 4 places.

    The snapshot is written to a temporary sibling file and then renamed over
    the real state file. Writing in place truncates the file first, so an
    interruption mid-write (e.g. Ctrl-C) would leave invalid JSON behind and
    make the next resume fail in ``load_state``.
    """
    snapshot = {
        "ingested": sorted(ingested),
        "started_at": started_at,
        "last_updated": time.strftime("%Y-%m-%dT%H:%M:%S"),
        "total_cost_estimate": round(total_cost_estimate, 4),
        "count": len(ingested),
    }
    # Same directory as STATE_FILE, so the rename stays on one filesystem.
    tmp_path = STATE_FILE.with_name(STATE_FILE.name + ".tmp")
    tmp_path.write_text(json.dumps(snapshot, indent=2))
    tmp_path.replace(STATE_FILE)
|
|
|
|
|
|
def fetch_tier1_sources(cur, max_sources, exclude_set):
    """Select the most recent sources from the ``embeddings`` table.

    Args:
        cur: An open DB-API cursor.
        max_sources: Maximum number of sources to return.
        exclude_set: Source names to skip (already ingested).

    Returns:
        A list of ``(source, full_doc, most_recent)`` tuples, where
        ``full_doc`` is every ``document`` row of the source joined with a
        blank line in ``id`` order. ``most_recent`` is ``MAX(created_at)``
        when that column exists, otherwise ``NULL``.
    """
    # Recency is ordered by created_at when the table has it, else by id.
    cur.execute("""
        SELECT column_name FROM information_schema.columns
        WHERE table_name = 'embeddings'
    """)
    has_created_at = any(row[0] == "created_at" for row in cur.fetchall())

    recency_expr = "MAX(created_at)" if has_created_at else "NULL"
    order_clause = "MAX(created_at) DESC NULLS LAST" if has_created_at else "MAX(id) DESC"

    # Fetch twice as many rows as requested so skipped sources do not starve
    # the result.
    # NOTE(review): on a resumed run this can select sources beyond the
    # original top-`max_sources` window — confirm that is intended.
    cur.execute(f"""
        SELECT source,
               STRING_AGG(document, E'\n\n' ORDER BY id) AS full_doc,
               {recency_expr} AS most_recent
        FROM embeddings
        GROUP BY source
        ORDER BY {order_clause}
        LIMIT %s
    """, (max_sources * 2,))

    picked = []
    for name, text, latest in cur.fetchall():
        # Keep only sources with a non-empty document that are not excluded.
        if text and name not in exclude_set:
            picked.append((name, text, latest))
            if len(picked) >= max_sources:
                break
    return picked
|
|
|
|
|
|
def submit_batch(batch):
    """POST one batch of episodes to the Graphiti sidecar's bulk endpoint.

    Args:
        batch: Sequence of ``(source, doc)`` pairs.

    Returns:
        A dict with keys ``ok``, ``status_code``, ``elapsed_s``, ``error``
        and ``sources``. This function never raises for request failures:
        any exception is reported as ``ok=False`` with ``status_code=None``
        and the exception text (first 500 characters) in ``error``.
    """
    episodes = []
    for name, text in batch:
        episodes.append({
            "name": name,
            # Only the first 12000 characters of each document are sent.
            "content": text[:12000],
            "source_description": "tier1_migration",
            "timestamp": "2026-04-28T00:00:00",
        })
    payload = {"episodes": episodes}

    def outcome(ok, status_code, elapsed, error):
        # Single place that fixes the shape (and key order) of the result.
        return {
            "ok": ok,
            "status_code": status_code,
            "elapsed_s": round(elapsed, 2),
            "error": error,
            "sources": [name for name, _ in batch],
        }

    started = time.time()
    try:
        response = requests.post(f"{GRAPHITI_URL}/episodes/bulk", json=payload, timeout=900)
        elapsed = time.time() - started
        failure_text = None if response.ok else response.text[:500]
        return outcome(response.ok, response.status_code, elapsed, failure_text)
    except Exception as exc:
        return outcome(False, None, time.time() - started, str(exc)[:500])
|
|
|
|
|
|
def chunk_for_batches(sources, base_batch_size, long_threshold, long_batch_size):
    """Group sources into submission batches, using smaller batches for long docs.

    Args:
        sources: Iterable of ``(source, doc, _)`` triples; the third element
            is ignored.
        base_batch_size: Batch size for documents shorter than
            ``long_threshold`` characters.
        long_threshold: ``len(doc)`` at or above which a document is long.
        long_batch_size: Batch size for long documents.

    Yields:
        Lists of ``(source, doc)`` pairs in input order. A batch is closed
        when it reaches its size target or when the next document needs a
        different target, so documents with different targets never share a
        batch.
    """
    batch = []
    current_size_target = base_batch_size

    for source, doc, _ in sources:
        target_size = long_batch_size if len(doc) >= long_threshold else base_batch_size

        if batch and target_size != current_size_target:
            yield batch
            batch = []
        # Refresh the target even when the batch is empty. Previously it was
        # only updated when a non-empty batch was flushed, so a long document
        # that opened a fresh batch inherited the old target and could be
        # grouped with short documents past the long-document limit.
        current_size_target = target_size

        batch.append((source, doc))
        if len(batch) >= current_size_target:
            yield batch
            batch = []

    if batch:
        yield batch
|
|
|
|
|
|
def main():
    """Command-line entry point for the tier 1 migration.

    Steps, in order:
      1. Parse ``--dry-run``, ``--max-sources`` and ``--reset``.
      2. With ``--reset``, delete the state file after confirmation.
      3. Check the sidecar's ``/health`` endpoint; stop if it is unhealthy.
      4. Load the resume state and fetch the sources still to ingest.
      5. With ``--dry-run``, print the planned batches and estimates, then stop.
      6. Otherwise submit the batches one by one, saving state after each so
         an interrupted run can be resumed, and write a summary to
         ``RESULTS_FILE``.

    Returns ``None`` in every case; problems are reported on stdout.
    """
    # Per-episode figures behind the forecast and the running cost tally.
    # NOTE(review): presumably measured on earlier runs — confirm.
    cost_per_episode = 0.103
    est_seconds_per_episode = 8
    # Documents below this length are reported as "short".
    medium_doc_threshold = 1000

    parser = argparse.ArgumentParser()
    parser.add_argument("--dry-run", action="store_true")
    parser.add_argument("--max-sources", type=int, default=MAX_SOURCES)
    parser.add_argument("--reset", action="store_true")
    args = parser.parse_args()

    if args.reset and STATE_FILE.exists():
        confirm = input(f"Delete state file {STATE_FILE}? [y/N] ")
        if confirm.lower() == "y":
            STATE_FILE.unlink()
            print("State file deleted. Resuming from scratch.")
        else:
            print("Aborted.")
            return

    print("=" * 70)
    print("Tier 1 Graphiti Migration")
    print("=" * 70)

    # Refuse to start if the sidecar is down or unhealthy.
    try:
        r = requests.get(f"{GRAPHITI_URL}/health", timeout=10)
        if not r.ok:
            print(f"ERROR: sidecar /health returned {r.status_code}")
            return
        print(f"Sidecar: {r.json()}")
    except Exception as e:
        print(f"ERROR: sidecar unreachable: {e}")
        return

    ingested, started_at, prior_cost = load_state()
    if started_at:
        print(f"Resuming run started at {started_at}")
        print(f" {len(ingested)} sources already ingested")
        print(f" Estimated cost so far: ${prior_cost:.2f}")
    else:
        started_at = time.strftime("%Y-%m-%dT%H:%M:%S")
        print(f"Fresh run starting at {started_at}")

    print()
    print(f"Fetching tier 1 sources from pgvector (max={args.max_sources})...")
    conn = psycopg2.connect(PG_DSN)
    try:
        cur = conn.cursor()
        try:
            sources = fetch_tier1_sources(cur, args.max_sources, ingested)
        finally:
            cur.close()
    finally:
        # Release the connection even when the fetch raises.
        conn.close()

    print(f" {len(sources)} sources to ingest (excluding {len(ingested)} already done)")

    if not sources:
        print()
        print("Nothing to do — all tier 1 sources already ingested.")
        return

    # Use the same long-document threshold as batching and the per-batch label
    # below, so the summary cannot drift if LONG_DOC_THRESHOLD is changed.
    short = sum(1 for _, d, _ in sources if len(d) < medium_doc_threshold)
    medium = sum(1 for _, d, _ in sources if medium_doc_threshold <= len(d) < LONG_DOC_THRESHOLD)
    long_ = sum(1 for _, d, _ in sources if len(d) >= LONG_DOC_THRESHOLD)
    print(f" Distribution: short={short} medium={medium} long={long_}")
    print()

    batches = list(chunk_for_batches(sources, BATCH_SIZE, LONG_DOC_THRESHOLD, LONG_DOC_BATCH_SIZE))
    print(f" Will submit {len(batches)} batches (delay {BATCH_DELAY_S}s between)")
    print()

    if args.dry_run:
        n_episodes = sum(len(b) for b in batches)
        print("DRY RUN - first 10 batches:")
        for i, batch in enumerate(batches[:10], 1):
            print(f" [{i}] n={len(batch)} sources={[s[:40] for s, _ in batch]}")
        print(f" ... and {max(0, len(batches) - 10)} more batches")
        print()
        print("Estimated cost: ${:.2f}".format(cost_per_episode * n_episodes))
        print("Estimated runtime: {:.1f} hours".format(
            (n_episodes * est_seconds_per_episode + len(batches) * BATCH_DELAY_S) / 3600
        ))
        return

    total_start = time.time()
    batch_results = []
    successful_episodes = 0
    failed_episodes = 0
    estimated_cost = prior_cost

    for i, batch in enumerate(batches, 1):
        avg_chars = int(sum(len(d) for _, d in batch) / len(batch))
        bucket = "long" if avg_chars >= LONG_DOC_THRESHOLD else ("medium" if avg_chars >= medium_doc_threshold else "short")
        print(f"[{i:3d}/{len(batches)}] [{bucket:6s}] n={len(batch)} avg={avg_chars:6d}c", end=" ", flush=True)

        result = submit_batch(batch)
        batch_results.append(result)

        if result["ok"]:
            # Print the real status code; any 2xx counts as success.
            print(f" {result['status_code']} {result['elapsed_s']}s")
            ingested.update(source for source, _ in batch)
            successful_episodes += len(batch)
            estimated_cost += cost_per_episode * len(batch)
            save_state(ingested, started_at, estimated_cost)
        else:
            error_text = result["error"] or ""
            print(f" FAIL: {error_text[:80]}")
            failed_episodes += len(batch)
            # Nothing was ingested, but this refreshes last_updated.
            save_state(ingested, started_at, estimated_cost)

            # Extra back-off, chosen from what the failure looks like.
            lowered = error_text.lower()
            if "Max pending queries" in error_text:
                print(" FalkorDB queue overflow - pausing 30s")
                time.sleep(30)
            elif "timed out" in lowered:
                print(" Query timeout - pausing 15s")
                time.sleep(15)
            elif result["status_code"] == 429 or "rate" in lowered or "429" in error_text:
                print(" Rate limited - pausing 60s")
                time.sleep(60)

        if i < len(batches):
            time.sleep(BATCH_DELAY_S)

    total_elapsed = time.time() - total_start

    summary = {
        "started_at": started_at,
        "completed_at": time.strftime("%Y-%m-%dT%H:%M:%S"),
        "total_elapsed_s": round(total_elapsed, 1),
        "total_elapsed_hours": round(total_elapsed / 3600, 2),
        "n_batches": len(batches),
        "successful_episodes": successful_episodes,
        "failed_episodes": failed_episodes,
        "total_ingested_now": len(ingested),
        "estimated_total_cost": round(estimated_cost, 2),
        "batch_results": batch_results,
    }
    RESULTS_FILE.write_text(json.dumps(summary, indent=2))

    print()
    print("=" * 70)
    print("TIER 1 MIGRATION COMPLETE")
    print("=" * 70)
    print(f"Successful episodes: {successful_episodes}/{successful_episodes + failed_episodes}")
    print(f"Failed episodes: {failed_episodes}")
    print(f"Total ingested now: {len(ingested)}")
    print(f"Wall-clock: {total_elapsed/3600:.2f} hours")
    print(f"Estimated cost: ${estimated_cost:.2f}")
    print()
    print(f"State file: {STATE_FILE}")
    print(f"Results file: {RESULTS_FILE}")
|
|
|
|
|
|
# Script entry point: run the migration and turn Ctrl-C into a clean exit.
if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        # main() saves state after every batch, so re-running resumes from
        # the last completed batch.
        print()
        print("Interrupted. State saved. Re-run to resume.")
        # 130 is the conventional shell exit status for SIGINT (128 + 2).
        raise SystemExit(130)
|