Migrate to pgvector — remove ChromaDB from api.py, ingest scripts, dream.py

commit f78b83042b
parent d2eed98906
2026-04-26 21:16:04 +00:00

6 changed files with 250 additions and 83 deletions
api.py  (+44 -26)
@@ -6,10 +6,11 @@ import hashlib
 from pathlib import Path
 from datetime import datetime
 from dotenv import load_dotenv
-import chromadb
 from sentence_transformers import SentenceTransformer
 import anthropic
 from fastapi import FastAPI, Request, Response, Depends, HTTPException
+import psycopg2
+import psycopg2.extras
 from fastapi import UploadFile, File
 import tempfile
 import os
@@ -46,6 +47,10 @@ DEFAULT_SETTINGS = {
 }
 print("Loading Aaron AI...")
+
+PG_DSN = os.getenv("PG_DSN", "dbname=aaronai user=aaronai password=aaronai_db_password host=localhost")
+def get_pg():
+    return psycopg2.connect(PG_DSN)
 WHISPER_PROMPT = (
     "Grasshopper, Rhino, PolyJet, SLA, FDM, DMLS, ChromaDB, "
     "HVAMC, FWN3D, Mossygear, Nextcloud, Gitea, computational design, "
@@ -59,11 +64,7 @@ if HAS_WHISPER:
     except Exception as e:
         print(f"Whisper not available: {e}")
 
 embedder = SentenceTransformer("all-MiniLM-L6-v2")
-chroma_client = chromadb.PersistentClient(path=DB_PATH)
-collection = chroma_client.get_or_create_collection(
-    name="aaronai",
-    metadata={"hnsw:space": "cosine", "hnsw:allow_replace_deleted": True}
-)
+# ChromaDB removed — using pgvector
 anthropic_client = anthropic.Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY"))
 SYSTEM_PROMPT = """You are the personal AI assistant of Aaron Nelson — computational
@@ -210,11 +211,17 @@ def remove_from_memory(item):
 
 def get_pinned_cv_context():
     try:
-        results = collection.get(
-            where={"source": {"$in": CV_SOURCES}},
-            include=["documents", "metadatas"]
+        pg = get_pg()
+        cur = pg.cursor()
+        cur.execute(
+            "SELECT document, source FROM embeddings WHERE source = ANY(%s)",
+            (CV_SOURCES,)
         )
-        return results["documents"], results["metadatas"]
+        rows = cur.fetchall()
+        pg.close()
+        docs = [r[0] for r in rows]
+        metas = [{"source": r[1]} for r in rows]
+        return docs, metas
     except:
         return [], []
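
One caveat worth noting here: `pg.close()` only runs on the success path, so a failed query leaks the connection and the bare `except` hides it. A sketch of the same lookup with guaranteed cleanup via `contextlib.closing`; the helper name `fetch_cv_rows` is hypothetical, not part of this commit.

    from contextlib import closing

    def fetch_cv_rows(cv_sources):
        # closing() guarantees close() even if execute() raises;
        # the cursor context manager releases its resources as well.
        with closing(get_pg()) as pg, pg.cursor() as cur:
            cur.execute(
                "SELECT document, source FROM embeddings WHERE source = ANY(%s)",
                (cv_sources,),
            )
            return cur.fetchall()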
@@ -227,12 +234,7 @@ def is_professional_query(query):
     return any(k in query.lower() for k in keywords)
 
 def retrieve_context(query, n_results=8):
-    query_embedding = embedder.encode([query]).tolist()
-    results = collection.query(
-        query_embeddings=query_embedding,
-        n_results=n_results,
-        include=["documents", "metadatas", "distances"]
-    )
+    query_embedding = embedder.encode([query]).tolist()[0]
     context_pieces = []
     sources = []
     if is_professional_query(query):
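
The `[0]` added above matters: `encode()` on a one-element list returns a 2-D result, and the SQL parameters below need a flat vector. A quick shape check, assuming the module-level `embedder` defined earlier in the file:

    vec = embedder.encode(["sample query"]).tolist()   # nested lists, shape (1, 384)
    assert len(vec) == 1 and len(vec[0]) == 384
    query_embedding = vec[0]                           # flat 384-float list for %s::vector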
@@ -240,15 +242,24 @@ def retrieve_context(query, n_results=8):
         for doc, meta in zip(cv_docs, cv_metas):
             context_pieces.append(f"[CV] {doc}")
             sources.append(meta.get("source", "CV"))
-    for doc, meta, dist in zip(
-        results["documents"][0],
-        results["metadatas"][0],
-        results["distances"][0]
-    ):
-        relevance = 1 - dist
-        if relevance > 0.3 and meta.get("source") not in CV_SOURCES:
-            context_pieces.append(doc)
-            sources.append(meta.get("source", "unknown"))
+    try:
+        pg = get_pg()
+        cur = pg.cursor()
+        cur.execute("""
+            SELECT document, source, 1 - (embedding <=> %s::vector) AS similarity
+            FROM embeddings
+            WHERE source NOT IN %s
+            ORDER BY embedding <=> %s::vector
+            LIMIT %s
+        """, (query_embedding, tuple(CV_SOURCES) if CV_SOURCES else ('__none__',),
+              query_embedding, n_results))
+        for doc, source, similarity in cur.fetchall():
+            if similarity > 0.3:
+                context_pieces.append(doc)
+                sources.append(source or "unknown")
+        pg.close()
+    except Exception as e:
+        print(f"pgvector retrieval error: {e}")
     return context_pieces, sources
 
 def get_conversation_history(conversation_id, limit=20):
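
Parameter binding works here because psycopg2 renders a Python float list as a SQL array and recent pgvector releases accept an array-to-vector cast; the explicit `%s::vector` is what triggers that cast. An alternative sketch using the `pgvector` helper package, which registers a proper adapter on the connection; this commit does not use it, and `get_pg`/`query_embedding` are reused from the code above.

    import numpy as np
    from pgvector.psycopg2 import register_vector

    pg = get_pg()
    register_vector(pg)  # registers the vector type for this connection
    with pg.cursor() as cur:
        cur.execute(
            "SELECT document, 1 - (embedding <=> %s) AS similarity "
            "FROM embeddings ORDER BY embedding <=> %s LIMIT %s",
            (np.array(query_embedding), np.array(query_embedding), 8),
        )
        rows = cur.fetchall()
    pg.close()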
@@ -519,7 +530,14 @@ async def update_memory(request: Request, auth: str = Depends(require_auth)):
 
 @app.get("/api/status")
 async def get_status(auth: str = Depends(require_auth)):
-    chunk_count = collection.count()
+    try:
+        pg = get_pg()
+        cur = pg.cursor()
+        cur.execute("SELECT COUNT(*) FROM embeddings")
+        chunk_count = cur.fetchone()[0]
+        pg.close()
+    except:
+        chunk_count = 0
+
     # Watcher status
     watcher_running = False
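
With no index, every `ORDER BY embedding <=> ...` in this commit is an exact sequential scan, which is fine at small chunk counts but worth indexing as the table grows. A follow-up sketch, not part of this commit; `lists = 100` is a placeholder to tune against row count.

    pg = get_pg()
    with pg, pg.cursor() as cur:
        # vector_cosine_ops matches the <=> operator used in retrieve_context
        cur.execute(
            "CREATE INDEX IF NOT EXISTS embeddings_embedding_idx "
            "ON embeddings USING ivfflat (embedding vector_cosine_ops) "
            "WITH (lists = 100)"
        )
    pg.close()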