Migrate to pgvector — remove ChromaDB from api.py, the ingest scripts, and dream.py
This commit is contained in:
+33
-36
import argparse
from pathlib import Path
from datetime import datetime, timedelta

from dotenv import load_dotenv
import psycopg2

# Pull environment overrides from the project's .env file before reading config.
load_dotenv(Path.home() / "aaronai" / ".env")

# DSN for the pgvector-backed Postgres store; the PG_DSN environment variable
# overrides this local-development default.
PG_DSN = os.getenv("PG_DSN", "dbname=aaronai user=aaronai password=aaronai_db_password host=localhost")
def get_pg():
    """Open and return a new psycopg2 connection using the module-level PG_DSN.

    Callers are responsible for closing the returned connection.
    """
    conn = psycopg2.connect(PG_DSN)
    return conn
# ─── Paths ──────────────────────────────────────────────────────────────────
# NOTE(review): DB_PATH was the ChromaDB persistent-store directory; after the
# pgvector migration it may be obsolete — confirm no remaining callers before
# removing it.
DB_PATH = str(Path.home() / "aaronai" / "db")
@@ -115,16 +120,9 @@ def check_recent_journal(days=3):
|
||||
# ─── Stage 3: Retrieve ──────────────────────────────────────────────────────
|
||||
|
||||
def retrieve(mode, task=None, project=None, n_results=8):
|
||||
import chromadb
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
embedder = SentenceTransformer("all-MiniLM-L6-v2")
|
||||
client = chromadb.PersistentClient(path=DB_PATH)
|
||||
collection = client.get_or_create_collection(
|
||||
name="aaronai",
|
||||
metadata={"hnsw:space": "cosine", "hnsw:allow_replace_deleted": True}
|
||||
)
|
||||
|
||||
embedder = SentenceTransformer("all-MiniLM-L6-v2")
|
||||
low, high = MODE_RANGES[mode]
|
||||
|
||||
if task:
|
||||
@@ -138,38 +136,37 @@ def retrieve(mode, task=None, project=None, n_results=8):
|
||||
else:
|
||||
query = "research fabrication teaching practice recent work"
|
||||
|
||||
embedding = embedder.encode([query]).tolist()
|
||||
results = collection.query(
|
||||
query_embeddings=embedding,
|
||||
n_results=n_results * 3,
|
||||
include=["documents", "metadatas", "distances"]
|
||||
)
|
||||
embedding = embedder.encode([query]).tolist()[0]
|
||||
|
||||
chunks = []
|
||||
chunks = []
|
||||
seen_sources = set()
|
||||
|
||||
for doc, meta, dist in zip(
|
||||
results["documents"][0],
|
||||
results["metadatas"][0],
|
||||
results["distances"][0]
|
||||
):
|
||||
relevance = 1 - dist
|
||||
source = meta.get("source", "unknown")
|
||||
try:
|
||||
pg = get_pg()
|
||||
cur = pg.cursor()
|
||||
cur.execute("""
|
||||
SELECT document, source, 1 - (embedding <=> %s::vector) as similarity
|
||||
FROM embeddings
|
||||
ORDER BY embedding <=> %s::vector
|
||||
LIMIT %s
|
||||
""", (embedding, embedding, n_results * 3))
|
||||
|
||||
if not (low <= relevance <= high):
|
||||
continue
|
||||
if source in seen_sources:
|
||||
continue
|
||||
|
||||
chunks.append({
|
||||
"source": source,
|
||||
"content": doc,
|
||||
"relevance": relevance,
|
||||
})
|
||||
seen_sources.add(source)
|
||||
|
||||
if len(chunks) >= n_results:
|
||||
break
|
||||
for doc, source, similarity in cur.fetchall():
|
||||
if not (low <= similarity <= high):
|
||||
continue
|
||||
if source in seen_sources:
|
||||
continue
|
||||
chunks.append({
|
||||
"source": source or "unknown",
|
||||
"content": doc,
|
||||
"relevance": similarity,
|
||||
})
|
||||
seen_sources.add(source)
|
||||
if len(chunks) >= n_results:
|
||||
break
|
||||
pg.close()
|
||||
except Exception as e:
|
||||
print(f"pgvector retrieval error: {e}")
|
||||
|
||||
return chunks
|
||||
|
||||
|
||||
Reference in New Issue
Block a user