Commit: Migrate to pgvector — remove ChromaDB from api.py, the ingest scripts, and dream.py (diffstat: +44 −26).
@@ -6,10 +6,11 @@ import hashlib
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from dotenv import load_dotenv
|
||||
import chromadb
|
||||
from sentence_transformers import SentenceTransformer
|
||||
import anthropic
|
||||
from fastapi import FastAPI, Request, Response, Depends, HTTPException
|
||||
import psycopg2
|
||||
import psycopg2.extras
|
||||
from fastapi import UploadFile, File
|
||||
import tempfile
|
||||
import os
|
||||
@@ -46,6 +47,10 @@ DEFAULT_SETTINGS = {
|
||||
}
|
||||
|
||||
print("Loading Aaron AI...")

# Connection string for the pgvector-backed Postgres instance.
# NOTE(review): the fallback DSN embeds a password in source — confirm this
# is a dev-only default and that production always sets PG_DSN in the env.
PG_DSN = os.getenv("PG_DSN", "dbname=aaronai user=aaronai password=aaronai_db_password host=localhost")


def get_pg():
    """Open and return a new psycopg2 connection built from PG_DSN.

    Callers own the connection and are responsible for closing it.
    """
    return psycopg2.connect(PG_DSN)
|
||||
WHISPER_PROMPT = (
|
||||
"Grasshopper, Rhino, PolyJet, SLA, FDM, DMLS, ChromaDB, "
|
||||
"HVAMC, FWN3D, Mossygear, Nextcloud, Gitea, computational design, "
|
||||
@@ -59,11 +64,7 @@ if HAS_WHISPER:
|
||||
except Exception as e:
|
||||
print(f"Whisper not available: {e}")
|
||||
# Sentence embedder used for both ingest and query-time retrieval; the
# pgvector `embeddings` table must hold vectors produced by this same model
# (all-MiniLM-L6-v2, 384 dims) for the cosine-distance queries to be valid.
embedder = SentenceTransformer("all-MiniLM-L6-v2")

# ChromaDB removed — using pgvector
# (The former chroma_client / collection objects are gone; all vector reads
# and writes now go through get_pg() and the `embeddings` table.)

# Anthropic API client; key comes from the environment (.env via load_dotenv).
anthropic_client = anthropic.Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY"))
|
||||
|
||||
SYSTEM_PROMPT = """You are the personal AI assistant of Aaron Nelson — computational
|
||||
@@ -210,11 +211,17 @@ def remove_from_memory(item):
|
||||
|
||||
def get_pinned_cv_context():
    """Fetch the always-pinned CV document chunks from pgvector storage.

    Returns:
        tuple[list[str], list[dict]]: (documents, metadatas) where each
        metadata dict carries the chunk's ``source``. Returns ``([], [])``
        on any database failure so callers degrade gracefully.
    """
    try:
        pg = get_pg()
        try:
            cur = pg.cursor()
            # ANY(%s) lets psycopg2 adapt the Python list directly.
            cur.execute(
                "SELECT document, source FROM embeddings WHERE source = ANY(%s)",
                (CV_SOURCES,)
            )
            rows = cur.fetchall()
        finally:
            # Ensure the connection is released even if the query raises.
            pg.close()
        docs = [r[0] for r in rows]
        metas = [{"source": r[1]} for r in rows]
        return docs, metas
    except Exception:
        # Best-effort: pinned context is an enhancement, never a hard failure.
        return [], []
|
||||
|
||||
@@ -227,12 +234,7 @@ def is_professional_query(query):
|
||||
return any(k in query.lower() for k in keywords)
|
||||
|
||||
def retrieve_context(query, n_results=8):
    """Retrieve relevant context chunks for *query* via pgvector similarity.

    Professional-sounding queries additionally get the pinned CV chunks
    prepended. CV sources are excluded from the similarity search so they
    are never duplicated.

    Args:
        query: user query text.
        n_results: maximum number of similarity hits to consider.

    Returns:
        tuple[list[str], list[str]]: (context_pieces, sources), parallel lists.
    """
    # Single embedding vector for the query (encode returns a batch of one).
    query_embedding = embedder.encode([query]).tolist()[0]
    context_pieces = []
    sources = []

    if is_professional_query(query):
        # TODO(review): this call sits in a gap between diff hunks — confirm
        # against the full file that the omitted line is exactly this one.
        cv_docs, cv_metas = get_pinned_cv_context()
        for doc, meta in zip(cv_docs, cv_metas):
            context_pieces.append(f"[CV] {doc}")
            sources.append(meta.get("source", "CV"))

    try:
        pg = get_pg()
        cur = pg.cursor()
        # `<=>` is pgvector's cosine-distance operator; 1 - distance gives a
        # similarity in [0, 1]. CV sources are excluded (they are pinned
        # separately above); '__none__' is a sentinel because `NOT IN ()`
        # is invalid SQL when CV_SOURCES is empty.
        cur.execute("""
            SELECT document, source, 1 - (embedding <=> %s::vector) as similarity
            FROM embeddings
            WHERE source NOT IN %s
            ORDER BY embedding <=> %s::vector
            LIMIT %s
        """, (query_embedding, tuple(CV_SOURCES) if CV_SOURCES else ('__none__',),
              query_embedding, n_results))
        for doc, source, similarity in cur.fetchall():
            # Same relevance floor the ChromaDB version used (1 - dist > 0.3).
            if similarity > 0.3:
                context_pieces.append(doc)
                sources.append(source or "unknown")
        pg.close()
    except Exception as e:
        # Retrieval is best-effort: log and fall through with whatever we have.
        print(f"pgvector retrieval error: {e}")

    return context_pieces, sources
|
||||
|
||||
def get_conversation_history(conversation_id, limit=20):
|
||||
@@ -519,7 +530,14 @@ async def update_memory(request: Request, auth: str = Depends(require_auth)):
|
||||
|
||||
@app.get("/api/status")
|
||||
async def get_status(auth: str = Depends(require_auth)):
|
||||
chunk_count = collection.count()
|
||||
try:
|
||||
pg = get_pg()
|
||||
cur = pg.cursor()
|
||||
cur.execute("SELECT COUNT(*) FROM embeddings")
|
||||
chunk_count = cur.fetchone()[0]
|
||||
pg.close()
|
||||
except:
|
||||
chunk_count = 0
|
||||
|
||||
# Watcher status
|
||||
watcher_running = False
|
||||
|
||||
Reference in New Issue
Block a user