# Source file: aaronAI/scripts/cost_test_graphiti_bulk.py (Python, 180 lines, 6.3 KiB)

"""
Measure actual Graphiti BULK episode cost on a stratified sample.
Uses /episodes/bulk endpoint. Submits in small batches to avoid rate limits.
"""
import json, os, random, time
from pathlib import Path
import psycopg2, requests
from dotenv import load_dotenv
# Load environment variables from ~/aaronai/.env before any of them are read
# below (presumably this file supplies PG_DSN -- confirm against the .env).
load_dotenv(Path.home() / "aaronai" / ".env")
# Base URL of the Graphiti service; "/episodes/bulk" is appended when posting.
GRAPHITI_URL = "http://localhost:8001"
# Postgres connection string. Indexing os.environ raises KeyError at import
# time if PG_DSN is not set.
PG_DSN = os.environ["PG_DSN"]
# Nominal sample size. Not referenced elsewhere in this file; the per-bucket
# quotas in fetch_stratified_sample (15 + 25 + 10) add up to this value.
SAMPLE_SIZE = 50
# Maximum number of episodes sent in one bulk request.
BATCH_SIZE = 5
# Seed for the sampling RNG so repeated runs draw the same sample.
RANDOM_SEED = 42
# Destination of the JSON summary written at the end of the run.
OUT = Path.home() / "aaronai" / "experiments" / "graphiti_bulk_cost_test.json"
# Import-time side effect: make sure the output directory exists.
OUT.parent.mkdir(parents=True, exist_ok=True)
def fetch_stratified_sample():
    """Draw a reproducible, length-stratified sample of sources from Postgres.

    All rows of the ``embeddings`` table are grouped by ``source`` and their
    ``document`` values are concatenated (ordered by ``id``, separated by a
    blank line) into one text per source. Sources whose aggregated text is
    NULL or empty are dropped. The remainder is bucketed by character length
    and sampled without replacement:

    * short  (fewer than 1000 chars): up to 15 sources
    * medium (1000 to 4999 chars):    up to 25 sources
    * long   (5000 chars or more):    up to 10 sources

    A bucket holding fewer sources than its quota contributes all of them.

    Returns:
        list[tuple[str, str]]: ``(source, full_doc)`` pairs, short bucket
        first, then medium, then long.

    Raises:
        psycopg2.Error: if connecting to or querying the database fails.
    """
    conn = psycopg2.connect(PG_DSN)
    try:
        # The cursor context manager closes the cursor; the connection is
        # closed explicitly in ``finally`` so neither leaks if the query fails.
        with conn.cursor() as cur:
            cur.execute("""
                SELECT source, STRING_AGG(document, E'\\n\\n' ORDER BY id) AS full_doc
                FROM embeddings
                GROUP BY source
            """)
            rows = cur.fetchall()
    finally:
        conn.close()

    sources = [(source, doc) for source, doc in rows if doc]

    # Single pass over the pool; insertion order within each bucket matches
    # the query order, so the seeded draw below stays reproducible.
    short, medium, long_ = [], [], []
    for item in sources:
        size = len(item[1])
        if size < 1000:
            short.append(item)
        elif size < 5000:
            medium.append(item)
        else:
            long_.append(item)
    print(f"Pool: short={len(short)} medium={len(medium)} long={len(long_)}")

    # A private generator seeded with RANDOM_SEED yields the same draws as
    # seeding the module-level RNG, without altering global random state.
    rng = random.Random(RANDOM_SEED)
    sample = (
        rng.sample(short, min(15, len(short)))
        + rng.sample(medium, min(25, len(medium)))
        + rng.sample(long_, min(10, len(long_)))
    )
    print(f"Sample: {len(sample)} sources, batch_size={BATCH_SIZE}")
    return sample
def submit_bulk_batch(batch):
    """POST one batch of episodes to the Graphiti bulk endpoint and time it.

    Never raises: every failure is captured in the returned record so the
    caller can keep submitting the remaining batches.

    Args:
        batch: Sequence of ``(source, doc)`` pairs. Each pair becomes one
            episode named after ``source``; ``doc`` is truncated to its first
            12000 characters.

    Returns:
        dict with the keys

        * ``batch_size``: number of episodes in ``batch``.
        * ``status_code``: HTTP status, or ``None`` if no response arrived.
        * ``elapsed_s``: wall time of the request, rounded to 2 decimals.
        * ``elapsed_per_episode_s``: ``elapsed_s`` divided by the batch size,
          or ``None`` if no response arrived.
        * ``response``: decoded JSON body of a successful response, else
          ``None``.
        * ``error``: ``None`` on success, otherwise a non-empty message of at
          most 500 characters.
        * ``sources``: the source names in ``batch``, in order.
    """
    result = {
        "batch_size": len(batch),
        "status_code": None,
        "elapsed_s": 0.0,
        "elapsed_per_episode_s": None,
        "response": None,
        "error": None,
        "sources": [source for source, _ in batch],
    }
    if not batch:
        # Nothing to send; avoids a pointless request and a division by zero.
        result["error"] = "empty batch: nothing submitted"
        return result

    payload = {
        "episodes": [
            {
                "name": source,
                "content": doc[:12000],
                "source_description": "pgvector_migration_bulk_test",
                "timestamp": "2026-04-28T00:00:00",
            }
            for source, doc in batch
        ]
    }

    # perf_counter is monotonic, so the interval cannot be skewed by
    # wall-clock adjustments during a request that may run for minutes.
    started = time.perf_counter()
    try:
        r = requests.post(f"{GRAPHITI_URL}/episodes/bulk", json=payload, timeout=900)
    except Exception as e:
        # Deliberately broad: this is the best-effort boundary of the test
        # run, and one failed batch must not abort the remaining ones.
        result["elapsed_s"] = round(time.perf_counter() - started, 2)
        result["error"] = str(e)[:500] or type(e).__name__
        return result

    elapsed = time.perf_counter() - started
    result["status_code"] = r.status_code
    result["elapsed_s"] = round(elapsed, 2)
    result["elapsed_per_episode_s"] = round(elapsed / len(batch), 2)

    if not r.ok:
        # Fall back to the status code so an empty body never yields an
        # empty (falsy) error string.
        result["error"] = r.text[:500] or f"HTTP {r.status_code}"
        return result

    try:
        result["response"] = r.json()
    except ValueError as e:
        # A 2xx reply with a non-JSON body is still recorded as a failure,
        # but the real status code and timing are kept.
        result["error"] = f"invalid JSON in response: {e}"[:500]
    return result
def _count_corpus_sources():
    """Return the number of distinct sources in ``embeddings``, or None on DB failure."""
    try:
        conn = psycopg2.connect(PG_DSN)
        try:
            with conn.cursor() as cur:
                cur.execute("SELECT COUNT(DISTINCT source) FROM embeddings")
                return cur.fetchone()[0]
        finally:
            conn.close()
    except psycopg2.Error as exc:
        print(f"WARNING: could not count corpus sources: {exc}")
        return None


def main():
    """Run the interactive bulk-ingestion cost test end to end.

    Steps:
      1. Ask the operator to note the current Anthropic spend and wait for
         Enter.
      2. Draw the stratified sample and split it into batches of at most
         ``BATCH_SIZE`` episodes.
      3. Submit each batch, printing one progress line per batch and pausing
         30 seconds after a rate-limited batch when more batches remain.
      4. Build a summary (success and failure counts, timings, per-batch
         results, corpus size and, when computable, an extrapolated migration
         runtime) and write it as JSON to ``OUT``.
      5. Print the results and the manual follow-up steps for deriving the
         per-episode cost.

    Returns:
        None. Returns early, writing nothing, when the sample is empty.
    """
    banner = "=" * 60
    print(banner)
    print("Graphiti BULK Migration Cost Test (Haiku 4.5)")
    print(banner)
    print()
    print("BEFORE running:")
    print(" 1. Open https://console.anthropic.com/settings/usage")
    print(" 2. Note current spend.")
    print()
    input("Press Enter when noted... ")
    print()

    sample = fetch_stratified_sample()
    if not sample:
        print("ERROR: empty sample")
        return

    batches = [sample[i:i + BATCH_SIZE] for i in range(0, len(sample), BATCH_SIZE)]
    print(f"Submitting {len(batches)} batches of up to {BATCH_SIZE} episodes")
    print()

    results = []
    total_start = time.time()
    for i, batch in enumerate(batches, start=1):
        avg_chars = sum(len(doc) for _, doc in batch) // len(batch)
        print(f"[batch {i:2d}/{len(batches)}] n={len(batch)} avg_chars={avg_chars:6d}",
              end=" ", flush=True)
        result = submit_bulk_batch(batch)
        results.append(result)

        error = result["error"]
        # ``is not None`` matches the success/failure split used for the
        # summary below, so an empty error string is still shown as an error.
        if error is not None:
            print(f" ERROR: {error[:80]}")
            # The error text is the response body and need not mention the
            # status, so the recorded status code is checked as well.
            rate_limited = (
                result["status_code"] == 429
                or "429" in error
                or "rate" in error.lower()
            )
            # No point pausing when there is no further batch to protect.
            if rate_limited and i < len(batches):
                print(" Rate limited - pausing 30s before next batch")
                time.sleep(30)
        else:
            print(f" {result['status_code']} {result['elapsed_s']}s "
                  f"({result['elapsed_per_episode_s']}s/episode)")
    total_elapsed = time.time() - total_start

    successful_batches = [r for r in results if r["error"] is None]
    failed_batches = [r for r in results if r["error"] is not None]
    successful_episodes = sum(r["batch_size"] for r in successful_batches)
    failed_episodes = sum(r["batch_size"] for r in failed_batches)

    if successful_episodes:
        mean_per_episode = round(
            sum(r["elapsed_s"] for r in successful_batches) / successful_episodes, 2
        )
    else:
        mean_per_episode = None

    summary = {
        "sample_size": len(sample),
        "batch_size": BATCH_SIZE,
        "n_batches": len(batches),
        "successful_batches": len(successful_batches),
        "failed_batches": len(failed_batches),
        "successful_episodes": successful_episodes,
        "failed_episodes": failed_episodes,
        "total_elapsed_s": round(total_elapsed, 1),
        "mean_elapsed_per_episode_s": mean_per_episode,
        "results": results,
    }

    # The measurements above cost real money; a database hiccup at this point
    # must not prevent them from being written, so the count is best-effort.
    total_sources = _count_corpus_sources()
    summary["total_corpus_sources"] = total_sources
    can_estimate = mean_per_episode is not None and total_sources is not None
    if can_estimate:
        summary["estimated_migration_hours"] = round(
            total_sources * mean_per_episode / 3600, 1
        )
    OUT.write_text(json.dumps(summary, indent=2))

    corpus_label = total_sources if total_sources is not None else "unknown"

    print()
    print(banner)
    print("RESULTS")
    print(banner)
    print(f"Episodes: {summary['successful_episodes']}/{summary['sample_size']} succeeded")
    print(f"Batches: {summary['successful_batches']}/{summary['n_batches']} succeeded")
    print(f"Total elapsed: {summary['total_elapsed_s']}s")
    if can_estimate:
        print(f"Mean per episode: {summary['mean_elapsed_per_episode_s']}s")
        print(f"Total corpus sources: {summary['total_corpus_sources']}")
        print(f"Estimated migration runtime: {summary['estimated_migration_hours']} hours")
    print()
    print("AFTER:")
    # NOTE(review): the $28.61 baseline is hard-coded, while the operator is
    # asked above to note the current spend -- confirm which one is intended.
    print(" Wait 5 min; note new Anthropic spend; subtract from $28.61 baseline.")
    print(f" delta / {summary['successful_episodes']} = per-episode cost")
    print(f" per-episode * {corpus_label} = full migration estimate")
    print()
    print(f"Full results: {OUT}")


if __name__ == "__main__":
    main()