117 lines
3.9 KiB
Python
117 lines
3.9 KiB
Python
#!/usr/bin/env python3
|
|
"""E1 sample selection — pick 10 episodes from Tier 1 stratified by density and type."""
|
|
import json
|
|
import os
|
|
import subprocess
|
|
from pathlib import Path
|
|
from collections import defaultdict
|
|
|
|
# Where the E1 sample manifest is written: ~/aaronai/experiments/cascade_reextract_sample.json
EXPERIMENTS = Path.home().joinpath("aaronai", "experiments")
OUTPUT = EXPERIMENTS.joinpath("cascade_reextract_sample.json")


# Fetch the entity count of every Episodic node from FalkorDB.
# NOTE(review): the original header said "all Tier 1 episodes", but the Cypher
# query below has no tier filter — it returns every Episodic node in the graph.
# Confirm the target graph only holds Tier 1 episodes, or add a WHERE clause.

# Statistics lines FalkorDB appends after the result rows of GRAPH.QUERY.
_STATS_PREFIXES = ("Cached execution:", "Query internal execution time:")


def _parse_episode_rows(stdout):
    """Parse raw (non-TTY) redis-cli output of the episode-count query.

    redis-cli prints one value per line. The output has three parts, in order:
    - the two column headers ("name", "entities");
    - each row as a name line followed by a count line;
    - FalkorDB's statistics lines.

    Returns:
        A list of {"name": str, "entities": int} dicts in output order.
    """
    lines = [line for line in stdout.split("\n") if line.strip()]

    # Drop the column header only where it actually appears (the very start),
    # so an episode literally called "name" is not mistaken for it.
    if lines[:2] == ["name", "entities"]:
        lines = lines[2:]

    episodes = []
    i = 0
    while i + 1 < len(lines):
        # Match the full statistics prefixes so an episode whose name merely
        # starts with "Query" or "Cached" does not end parsing early.
        if lines[i].startswith(_STATS_PREFIXES):
            break
        try:
            count = int(lines[i + 1])
        except ValueError:
            # Not a (name, count) pair — presumably a row whose name printed as
            # a blank line and was filtered out above. Resynchronise by one line.
            i += 1
            continue
        episodes.append({"name": lines[i], "entities": count})
        i += 2
    return episodes


def query_episode_counts(container="falkordb", graph="aaron"):
    """Return [{"name": ..., "entities": ...}, ...] for every Episodic node.

    Runs a Cypher query through ``redis-cli`` inside a Docker container.
    ``entities`` is the number of distinct Entity nodes connected to the
    episode by a relationship of any type and direction. Rows are requested
    ordered by that count, highest first.

    Args:
        container: name of the Docker container running FalkorDB.
        graph: name of the graph to query.

    Raises:
        RuntimeError: if docker/redis-cli exits non-zero, or the reply is an error.
    """
    query = ("MATCH (e:Episodic) OPTIONAL MATCH (e)-[r]-(n:Entity) "
             "RETURN e.name AS name, count(distinct n) AS entities "
             "ORDER BY entities DESC")
    result = subprocess.run(
        ["docker", "exec", container, "redis-cli", "GRAPH.QUERY", graph, query],
        capture_output=True, text=True,
    )

    # A failed `docker exec` (e.g. container not running) used to fall through
    # as an empty episode list and crash much later; fail loudly here instead.
    if result.returncode != 0:
        detail = result.stderr.strip() or result.stdout.strip()
        raise RuntimeError(f"FalkorDB query failed (exit {result.returncode}): {detail}")

    # NOTE(review): some redis-cli versions exit 0 on an error reply and print it
    # on stdout; catch the common error prefixes as well — confirm for your version.
    reply = result.stdout.lstrip()
    if reply.startswith(("ERR", "(error)")):
        raise RuntimeError(f"FalkorDB query error: {reply.splitlines()[0]}")

    return _parse_episode_rows(result.stdout)


# Pull the (name, entity count) list once; everything below works on it.
print("Fetching episode entity counts from FalkorDB...")
episodes = query_episode_counts()
print("Got", len(episodes), "episodes")


# Type classification: uploaded documents vs. conversation episodes.
def is_document(name):
    """Return True if *name* ends in a document file extension (case-insensitive).

    Recognised extensions: .pdf, .docx, .pptx, .txt, .md.
    """
    # str.endswith accepts a tuple and tests every suffix in one call.
    return name.lower().endswith((".pdf", ".docx", ".pptx", ".txt", ".md"))


# Compute quartile boundaries from the entity counts.
# counts is sorted high -> low, so counts[n // 4] is the cut-off for the top
# quarter and counts[3 * n // 4] the cut-off for the bottom quarter.
# NOTE(review): the cut-offs are taken over *all* episodes, documents included,
# while the high/mid/low buckets below hold conversations only — confirm that
# this mixed population is intended.
counts = sorted([e["entities"] for e in episodes], reverse=True)
n = len(counts)
if n == 0:
    # Without this guard the indexing below fails with a bare IndexError.
    raise SystemExit("No episodes returned from FalkorDB; nothing to sample.")
top_q = counts[n // 4]  # 25th percentile from top
bottom_q = counts[3 * n // 4]  # 75th percentile from top

print(f"\nQuartile boundaries: top={top_q}+, middle=({bottom_q+1}-{top_q-1}), bottom=0-{bottom_q}")
if top_q - bottom_q < 2:
    # With heavy ties no integer count lies strictly between the two cut-offs.
    print("Warning: quartile boundaries leave no room for a middle band; "
          "the mid-density bucket will be empty.")

# Documents need at least this many entities to enter the sample.
MIN_DOC_ENTITIES = 5

# Bucket conversations by density (filter computed once); documents form their own stratum.
conversations = [e for e in episodes if not is_document(e["name"])]
high = [e for e in conversations if e["entities"] >= top_q]
mid = [e for e in conversations if bottom_q < e["entities"] < top_q]
low = [e for e in conversations if e["entities"] <= bottom_q]
docs = [e for e in episodes if is_document(e["name"]) and e["entities"] >= MIN_DOC_ENTITIES]

print(f"High-density conversations: {len(high)}")
print(f"Mid-density conversations: {len(mid)}")
print(f"Low-density conversations: {len(low)}")
print(f"Documents (≥{MIN_DOC_ENTITIES} entities): {len(docs)}")


# Deterministic selection — take from the middle of each bucket to avoid edge cases.
def pick(bucket, n):
    """Return up to *n* episodes taken from the middle of *bucket*.

    The bucket is first ordered by entity count (highest first) with the
    episode name as tie-breaker. The database does not guarantee an order
    for rows with equal counts, so without this the "middle" slice, and
    hence the sample, could differ between runs.

    Args:
        bucket: list of episode dicts with "name" and "entities" keys.
        n: number of episodes wanted.

    Returns:
        A new list: the whole ordered bucket if it has fewer than *n* entries,
        otherwise the *n* entries centred on its midpoint.
    """
    ordered = sorted(bucket, key=lambda ep: (-ep["entities"], ep["name"]))
    if len(ordered) < n:
        return ordered
    mid_idx = len(ordered) // 2
    start = max(0, mid_idx - n // 2)
    return ordered[start:start + n]


# Per-stratum quotas: 3 high + 3 mid + 2 low conversations + 2 documents = 10.
# Strata are concatenated in this order; a short stratum contributes what it has.
selected = [
    episode
    for stratum, quota in ((high, 3), (mid, 3), (low, 2), (docs, 2))
    for episode in pick(stratum, quota)
]


# Label each episode with the stratum it belongs to.
def bucket_for(ep):
    """Return "document", "high", "mid" or "low" for episode *ep*.

    Documents are recognised by file extension regardless of entity count;
    all other episodes are placed using the module-level top_q / bottom_q
    cut-offs.
    """
    if is_document(ep["name"]):
        return "document"
    entities = ep["entities"]
    if entities >= top_q:
        return "high"
    return "mid" if entities > bottom_q else "low"


# Attach the stratum label to every selected episode (mutating the dicts in
# place so it is also written to the JSON), then list the sample.
for episode in selected:
    episode["bucket"] = bucket_for(episode)

print(f"\nSelected {len(selected)} episodes for E1:")
for episode in selected:
    print(f" [{episode['bucket']:>8}] {episode['entities']:>3}e {episode['name']}")


# Save the selection together with the parameters that produced it.
# NOTE(review): "total_tier1_episodes" counts every episode returned by the
# FalkorDB query, which is not tier-filtered — confirm the key name still fits.
OUTPUT.parent.mkdir(parents=True, exist_ok=True)  # don't fail if the directory is missing
with OUTPUT.open("w", encoding="utf-8") as f:
    json.dump({
        "metadata": {
            # Report the real sample size: short buckets can yield fewer than 10.
            "purpose": f"E1 cascade re-extraction sample (n={len(selected)})",
            "stratification": "density buckets + document subset",
            "quartile_top": top_q,
            "quartile_bottom": bottom_q,
            "total_tier1_episodes": len(episodes),
        },
        "selected": selected,
    }, f, indent=2)

print(f"\nSaved to {OUTPUT}")