From 8d560f9f5e213ccc95685480bc2b6dcc682c4718 Mon Sep 17 00:00:00 2001
From: Aaron Nelson
Date: Tue, 19 May 2026 21:11:15 +0000
Subject: [PATCH] api.py: hybrid retrieval with intent routing and cross-encoder rerank
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Replaces pure-dense top-8 retrieval with a three-stage pipeline:
- Lexical full-text search (tsvector + websearch_to_tsquery, ordered by
  ts_rank, which is PostgreSQL's built-in ranking rather than true BM25)
  and dense (pgvector) candidate lists, fetched one after the other on the
  same connection and fused with Reciprocal Rank Fusion
- Optional type filter driven by classify_retrieval_intent() so questions
  about prior conversations don't pull documents and vice versa
- Cross-encoder rerank (ms-marco-MiniLM-L-6-v2) over RRF candidates before
  taking final top-N

Also adds scripts/reindex_docx_pptx.py — one-off re-ingest used to recover
table/header/text-box content in docx and pptx after the 93c0d89 extractor
upgrade — and scripts/test_retrieval.py to exercise the new pipeline against
representative queries.

Schema: requires GIN index on to_tsvector('english', document) (already
created out-of-band via psql since Apache AGE in shared_preload_libraries
blocks ALTER TABLE on this database).
---
Review notes (text between "---" and the first "diff --git" is ignored by git am):
- Line breaks and indentation were reconstructed from a whitespace-collapsed copy
  of this patch. Context/removed lines and the blob ids on the "index" lines were
  not re-derived from the tree; regenerate with `git format-patch` before applying.
- api.py: retrieve_context closes the connection in a `finally` (it leaked whenever
  a query raised); classify_retrieval_intent returns copies of the type lists and
  collapses whitespace before matching.
- reindex_docx_pptx.py: connections closed in `finally`; the embedder is loaded
  before the DELETE; an empty file list aborts before the DELETE; a file that
  raises during ingest is counted as failed instead of aborting the run.
- test_retrieval.py: comments corrected; exits with a clear message when api.py
  fails before defining the functions under test.

 .gitignore                   |   1 +
 scripts/api.py               | 163 +++++++++++++++++++++++++++++++++-----
 scripts/reindex_docx_pptx.py | 154 +++++++++++++++++++++++++++++++++++
 scripts/test_retrieval.py    |  64 +++++++++++++++
 4 files changed, 362 insertions(+), 20 deletions(-)
 create mode 100644 scripts/reindex_docx_pptx.py
 create mode 100644 scripts/test_retrieval.py

diff --git a/.gitignore b/.gitignore
index eef432d..28d7908 100644
--- a/.gitignore
+++ b/.gitignore
@@ -8,6 +8,7 @@ dreamer_state.json
 corpus_integrity_report.json
 watcher_state.json
 watcher_status.json
+reindex_status.json
 
 # Logs (these belong in /var/log/)
 *.log
diff --git a/scripts/api.py b/scripts/api.py
index 3143dae..b1bbf33 100644
--- a/scripts/api.py
+++ b/scripts/api.py
@@ -1,4 +1,5 @@
 import os
+import re
 import json
 import sqlite3
 import subprocess
@@ -6,7 +7,7 @@ import hashlib
 from pathlib import Path
 from datetime import datetime, timedelta
 from dotenv import load_dotenv
-from sentence_transformers import SentenceTransformer
+from sentence_transformers import SentenceTransformer, CrossEncoder
 import anthropic
 from fastapi import FastAPI, Request, Response, Depends, HTTPException, BackgroundTasks
 import psycopg2
@@ -91,6 +92,7 @@ if HAS_WHISPER:
     except Exception as e:
         print(f"Whisper not available: {e}")
 embedder = SentenceTransformer("all-MiniLM-L6-v2")
+reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
 # ChromaDB removed — using pgvector
 anthropic_client = anthropic.Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY"))
 
@@ -243,30 +245,150 @@ def remove_from_memory(item):
     save_memory("\n".join(filtered))
     return len(lines) - len(filtered)
 
-def retrieve_context(query, n_results=8):
-    """Pure semantic retrieval over pgvector. Top-N by cosine similarity, threshold 0.3.
-    No CV pinning, no keyword routing — see architecture doc substrate-dependency section.
-    Substrate-level workarounds (entity-keyed routing, hybrid retrieval) live at the
-    Graphiti layer, not as wrapper logic above pgvector."""
+HYBRID_CANDIDATES = 30
+RRF_K = 60
+FINAL_LIMIT = 8
+
+_TSQUERY_SANITIZE_RE = re.compile(r"[^\w\s\"'-]")
+
+CONVERSATION_TYPES = ["chatgpt_conversation", "claude_conversation", "aaronai_conversation"]
+DOCUMENT_TYPES = ["document"]
+MEMORY_TYPES = ["claude_memory"]
+
+_CONVO_SIGNALS = (
+    "what did i tell", "what did we discuss", "what did we talk",
+    "in our conversation", "you mentioned", "we talked about",
+    "earlier you said", "earlier i said", "did i tell you",
+    "did i say", "what did chatgpt", "what did claude",
+)
+_DOC_SIGNALS = (
+    "write me a bio", "draft a bio", "my bio", "my cv", "my resume",
+    "my professional", "my work history", "my exhibitions",
+    "my publications", "my syllabi", "my courses", "my teaching",
+    "my philosophy", "about my career", "draft a cover letter",
+    "draft my", "write a bio", "professional bio",
+)
+
+
+def _websearch_query(text: str) -> str:
+    """Strip characters websearch_to_tsquery doesn't handle cleanly. Quoted
+    phrases and 'or' are preserved by the function itself."""
+    return _TSQUERY_SANITIZE_RE.sub(" ", text).strip()
+
+
+def classify_retrieval_intent(query: str):
+    """Return a list of `type` values to filter retrieval on, or None for all types.
+
+    Implementation is a low-effort keyword classifier — explicitly tunable and
+    swappable. For more nuanced routing, replace this with an LLM classifier call
+    that returns the same shape: a list of valid type strings or None.
+
+    Precedence: conversation signals win over document signals — a question like
+    "what did I tell you about my CV" is asking about the conversation, not the CV.
+
+    Runs of whitespace in the query are collapsed before matching, and a fresh
+    list is returned on every call so callers cannot mutate the module constants.
+    """
+    q = " ".join(query.lower().split())
+    if any(s in q for s in _CONVO_SIGNALS):
+        return list(CONVERSATION_TYPES)
+    if any(s in q for s in _DOC_SIGNALS):
+        return list(DOCUMENT_TYPES)
+    return None
+
+
+def _rerank(query: str, candidates: list[tuple]) -> list[tuple]:
+    """Cross-encoder rerank. Candidates are (id, document, source) tuples.
+    Returns the same tuples reordered by reranker score (highest first)."""
+    if not candidates:
+        return []
+    pairs = [(query, row[1]) for row in candidates]
+    scores = reranker.predict(pairs)
+    return [row for row, _ in sorted(zip(candidates, scores),
+                                     key=lambda x: x[1], reverse=True)]
+
+
+def retrieve_context(query, n_results=FINAL_LIMIT, type_filter=None):
+    """Hybrid retrieval (dense + lexical, RRF fused) followed by cross-encoder rerank.
+
+    - Dense (pgvector) handles paraphrase / semantic similarity.
+    - Lexical (tsvector) catches rare named tokens (FWN3D, Sono-Tek, course codes)
+      the embedding model has no signal for.
+    - RRF combines the two rankings without calibrating score scales.
+    - Cross-encoder rerank scores each (query, chunk) pair jointly, bridging
+      semantic gaps that bi-encoders can't (e.g., "write me a bio" -> CV chunk).
+
+    type_filter: optional list of `type` values to restrict the candidate pool to.
+    If None or empty, retrieves from all types. Use classify_retrieval_intent() to derive.
+
+    Returns (context_pieces, sources), two parallel lists of at most n_results
+    items. Errors are printed and swallowed; the lists collected so far
+    (usually empty) are returned.
+    """
     query_embedding = embedder.encode([query]).tolist()[0]
+    ts_query = _websearch_query(query)
+
     context_pieces = []
     sources = []
+
+    where_sql = ""
+    type_param = ()
+    if type_filter:
+        where_sql = "WHERE type = ANY(%s)"
+        type_param = (list(type_filter),)
+
     try:
         pg = get_pg()
-        cur = pg.cursor()
-        cur.execute("""
-            SELECT document, source, 1 - (embedding <=> %s::vector) as similarity
-            FROM embeddings
-            ORDER BY embedding <=> %s::vector
-            LIMIT %s
-        """, (query_embedding, query_embedding, n_results))
-        for doc, source, similarity in cur.fetchall():
-            if similarity > 0.3:
-                context_pieces.append(doc)
-                sources.append(source or "unknown")
-        pg.close()
+        try:
+            cur = pg.cursor()
+
+            cur.execute(f"""
+                SELECT id, document, source
+                FROM embeddings
+                {where_sql}
+                ORDER BY embedding <=> %s::vector
+                LIMIT %s
+            """, (*type_param, query_embedding, HYBRID_CANDIDATES))
+            dense_hits = cur.fetchall()
+
+            lexical_hits = []
+            if ts_query:
+                lex_where = "to_tsvector('english', document) @@ websearch_to_tsquery('english', %s)"
+                full_where = (f"WHERE {lex_where} AND type = ANY(%s)"
+                              if type_filter else f"WHERE {lex_where}")
+                lex_params = ((ts_query, list(type_filter)) if type_filter else (ts_query,))
+                cur.execute(f"""
+                    SELECT id, document, source
+                    FROM embeddings
+                    {full_where}
+                    ORDER BY ts_rank(to_tsvector('english', document),
+                                     websearch_to_tsquery('english', %s)) DESC
+                    LIMIT %s
+                """, (*lex_params, ts_query, HYBRID_CANDIDATES))
+                lexical_hits = cur.fetchall()
+        finally:
+            # Close on every path: previously an exception raised by either
+            # query skipped pg.close() and leaked the connection.
+            pg.close()
+
+        # Reciprocal Rank Fusion: each list adds 1 / (RRF_K + rank + 1), rank 0-based.
+        scores = {}
+        rows_by_id = {}
+        for hits in (dense_hits, lexical_hits):
+            for rank, row in enumerate(hits):
+                scores[row[0]] = scores.get(row[0], 0) + 1.0 / (RRF_K + rank + 1)
+                rows_by_id[row[0]] = row
+
+        rrf_ranked = sorted(scores.items(), key=lambda kv: kv[1], reverse=True)
+        candidates = [rows_by_id[doc_id] for doc_id, _ in rrf_ranked]
+
+        for _id, doc, source in _rerank(query, candidates)[:n_results]:
+            context_pieces.append(doc)
+            sources.append(source or "unknown")
+
     except Exception as e:
-        print(f"pgvector retrieval error: {e}")
+        print(f"hybrid retrieval error: {e}")
+
     return context_pieces, sources
 
 def get_conversation_history(conversation_id, limit=20):
@@ -306,7 +428,8 @@ def create_conversation(title="New conversation"):
 
 def chat(user_message, conversation_id, settings, client_time=None):
     memory = load_memory()
-    context_pieces, sources = retrieve_context(user_message)
+    type_filter = classify_retrieval_intent(user_message)
+    context_pieces, sources = retrieve_context(user_message, type_filter=type_filter)
     history = get_conversation_history(conversation_id)
 
     context_parts = []
diff --git a/scripts/reindex_docx_pptx.py b/scripts/reindex_docx_pptx.py
new file mode 100644
index 0000000..d19b3ab
--- /dev/null
+++ b/scripts/reindex_docx_pptx.py
@@ -0,0 +1,154 @@
+"""One-off: re-ingest docx+pptx after the 2026-05-04 extractor upgrade (commit 93c0d89).
+
+Pre-upgrade extraction missed tables, headers/footers, text boxes, group shapes,
+and pptx notes — leaving CVs/dossiers as section-header skeletons in the index.
+
+Steps when run with --apply:
+  1. DELETE all embeddings rows where source ends in .docx or .pptx
+  2. Walk NEXTCLOUD_PATH and re-ingest every .docx/.pptx via _ingest_one
+  3. Stage 2 enqueue is suppressed (SKIP_STAGE2_ENQUEUE=1)
+
+Without --apply: dry-run. Counts files and chunks, prints a sample, writes nothing.
+"""
+
+import os
+import random
+import sys
+import time
+from pathlib import Path
+
+os.environ["SKIP_STAGE2_ENQUEUE"] = "1"
+
+from dotenv import load_dotenv
+load_dotenv(Path.home() / "aaronai" / ".env", override=True)
+
+import psycopg2
+from sentence_transformers import SentenceTransformer
+
+sys.path.insert(0, str(Path(__file__).parent))
+from ingest import _ingest_one, get_pg
+
+NEXTCLOUD_PATH = Path("/home/aaron/nextcloud/data/data/aaron/files")
+TARGET_EXTS = {".docx", ".pptx"}
+
+APPLY = "--apply" in sys.argv
+
+
+def count_stale():
+    """Return (ext, files, chunks) rows for the .docx/.pptx sources currently indexed."""
+    pg = get_pg()
+    try:
+        cur = pg.cursor()
+        cur.execute(
+            "SELECT lower(substring(source from '\\.[^.]+$')) AS ext, "
+            "COUNT(DISTINCT source) AS files, COUNT(*) AS chunks "
+            "FROM embeddings WHERE lower(source) ~ '\\.(docx|pptx)$' "
+            "GROUP BY 1 ORDER BY 1"
+        )
+        return cur.fetchall()
+    finally:
+        pg.close()  # also released when the query raises
+
+
+def delete_stale():
+    """Delete every embeddings row whose source ends in .docx/.pptx; return the row count."""
+    pg = get_pg()
+    try:
+        cur = pg.cursor()
+        cur.execute("DELETE FROM embeddings WHERE lower(source) ~ '\\.(docx|pptx)$'")
+        deleted = cur.rowcount
+        pg.commit()
+        return deleted
+    finally:
+        pg.close()  # also released when the DELETE or the commit raises
+
+
+def find_files():
+    """List .docx/.pptx files under NEXTCLOUD_PATH, skipping names starting with '~$' or '.'."""
+    return [
+        f for f in NEXTCLOUD_PATH.rglob("*")
+        if f.is_file()
+        and f.suffix.lower() in TARGET_EXTS
+        and not f.name.startswith(("~$", "."))
+    ]
+
+
+def main():
+    """Print the plan; with --apply, delete the stale rows and re-ingest every file found."""
+    print(f"Mode: {'APPLY (destructive)' if APPLY else 'DRY-RUN (no writes)'}")
+    print(f"Target: {NEXTCLOUD_PATH}")
+    print(f"Extensions: {sorted(TARGET_EXTS)}")
+    print(f"SKIP_STAGE2_ENQUEUE={os.environ.get('SKIP_STAGE2_ENQUEUE')}")
+    print()
+
+    print("Stale chunks currently in DB:")
+    for ext, n_files, n_chunks in count_stale():
+        print(f" {ext}: {n_files} files, {n_chunks} chunks")
+    print()
+
+    files = find_files()
+    by_ext = {}
+    for f in files:
+        by_ext.setdefault(f.suffix.lower(), []).append(f)
+    print(f"Files on disk to re-ingest:")
+    for ext, lst in sorted(by_ext.items()):
+        print(f" {ext}: {len(lst)} files")
+    print(f" total: {len(files)}")
+    print()
+    print("Sample (5 random):")
+    for f in random.sample(files, min(5, len(files))):
+        print(f" {f}")
+    print()
+
+    if not APPLY:
+        print("Dry-run only. Re-run with --apply to delete + re-ingest.")
+        return
+
+    if not files:
+        # rglob on a missing or unmounted NEXTCLOUD_PATH yields nothing; deleting
+        # now would empty the index with no file left to rebuild it from.
+        print("No .docx/.pptx files found on disk; refusing to delete.")
+        return
+
+    # Load the model before the destructive step so a load failure cannot
+    # leave the index emptied with nothing re-ingested.
+    print("Loading embedder...")
+    embedder = SentenceTransformer("all-MiniLM-L6-v2")
+    print()
+
+    print("Deleting stale chunks...")
+    n = delete_stale()
+    print(f" deleted {n} rows")
+    print()
+
+    print(f"Re-ingesting {len(files)} files...")
+    started = time.time()
+    ingested = failed = total_chunks = 0
+    for i, f in enumerate(files, 1):
+        try:
+            # `or 0`: count a None return as nothing written (the return
+            # contract of _ingest_one is not visible from this script).
+            n = _ingest_one(f, embedder, root=NEXTCLOUD_PATH) or 0
+        except Exception as e:
+            # Keep going: the old rows are already deleted, so aborting here
+            # would leave every remaining file missing from the index.
+            print(f" ! {f}: {e}")
+            n = 0
+        if n > 0:
+            ingested += 1
+            total_chunks += n
+        else:
+            failed += 1
+        if i % 25 == 0 or i == len(files):
+            elapsed = time.time() - started
+            rate = i / elapsed if elapsed else 0
+            print(f" [{i}/{len(files)}] ingested={ingested} failed={failed} "
+                  f"chunks={total_chunks} ({rate:.1f} files/s)")
+    elapsed = time.time() - started
+    print()
+    print(f"Done in {elapsed:.0f}s: {ingested} ingested, {failed} failed, "
+          f"{total_chunks} chunks written.")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/test_retrieval.py b/scripts/test_retrieval.py
new file mode 100644
index 0000000..516a3b1
--- /dev/null
+++ b/scripts/test_retrieval.py
@@ -0,0 +1,64 @@
+"""End-to-end test of retrieve_context with intent routing + reranking.
+
+Loads api.py directly by path (which executes its module-level code, with the
+anthropic SDK stubbed out), replicates the chat-handler retrieval call shape,
+and prints classifier output + final ranked sources for each query.
+"""
+
+import os
+import sys
+from pathlib import Path
+
+from dotenv import load_dotenv
+load_dotenv(Path.home() / "aaronai" / ".env", override=True)
+
+sys.path.insert(0, str(Path(__file__).parent))
+
+# Stub anthropic so api.py import doesn't fail without the SDK loaded.
+# We only need retrieve_context + classify_retrieval_intent.
+import types
+sys.modules.setdefault("anthropic", types.ModuleType("anthropic"))
+sys.modules["anthropic"].Anthropic = lambda **kw: None
+
+# Install an empty faster_whisper stub too, unless something already loaded it.
+if "faster_whisper" not in sys.modules:
+    sys.modules["faster_whisper"] = types.ModuleType("faster_whisper")
+
+import importlib.util
+spec = importlib.util.spec_from_file_location("api", Path(__file__).parent / "api.py")
+api = importlib.util.module_from_spec(spec)
+# exec_module runs all of api.py's module-level code. An error part-way through is
+# tolerated as long as the two functions under test were defined before it.
+try:
+    spec.loader.exec_module(api)
+except Exception as e:
+    print(f"(continuing despite api.py side-effect error: {e})")
+
+missing = [name for name in ("retrieve_context", "classify_retrieval_intent")
+           if not hasattr(api, name)]
+if missing:
+    sys.exit(f"api.py failed before defining: {', '.join(missing)}")
+
+retrieve_context = api.retrieve_context
+classify_retrieval_intent = api.classify_retrieval_intent
+
+QUERIES = [
+    "write me a bio",
+    "my professional bio",
+    "draft a bio for the Utah application",
+    "Aaron Nelson CV consulting and design work",
+    "FWN3D consulting",
+    "syllabi I have taught",
+    "philosophy of teaching",
+    "what did I tell Claude about FWN3D",
+    "what did we discuss about the Utah job",
+    "Hudson Valley Additive Manufacturing Center",
+]
+
+for q in QUERIES:
+    intent = classify_retrieval_intent(q)
+    pieces, sources = retrieve_context(q, type_filter=intent)
+    print(f"\n=== {q!r} ===")
+    print(f" intent: {intent}")
+    for i, src in enumerate(sources, 1):
+        print(f" {i}. {src}")