api.py: hybrid retrieval with intent routing and cross-encoder rerank
Replaces pure-dense top-8 retrieval with a three-stage pipeline:
- BM25 (tsvector + websearch_to_tsquery) and dense (pgvector) in parallel,
fused with Reciprocal Rank Fusion
- Optional type filter driven by classify_retrieval_intent() so questions
about prior conversations don't pull documents and vice versa
- Cross-encoder rerank (ms-marco-MiniLM-L-6-v2) over RRF candidates before
taking final top-N
Also adds scripts/reindex_docx_pptx.py — one-off re-ingest used to recover
table/header/text-box content in docx and pptx after the 93c0d89 extractor
upgrade — and scripts/test_retrieval.py to exercise the new pipeline against
representative queries.
Schema: requires GIN index on to_tsvector('english', document) (already
created out-of-band via psql since Apache AGE in shared_preload_libraries
blocks ALTER TABLE on this database).
This commit is contained in:
+128
-15
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
import re
|
||||
import json
|
||||
import sqlite3
|
||||
import subprocess
|
||||
@@ -6,7 +7,7 @@ import hashlib
|
||||
from pathlib import Path
|
||||
from datetime import datetime, timedelta
|
||||
from dotenv import load_dotenv
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from sentence_transformers import SentenceTransformer, CrossEncoder
|
||||
import anthropic
|
||||
from fastapi import FastAPI, Request, Response, Depends, HTTPException, BackgroundTasks
|
||||
import psycopg2
|
||||
@@ -91,6 +92,7 @@ if HAS_WHISPER:
|
||||
except Exception as e:
|
||||
print(f"Whisper not available: {e}")
|
||||
embedder = SentenceTransformer("all-MiniLM-L6-v2")
|
||||
reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
|
||||
# ChromaDB removed — using pgvector
|
||||
anthropic_client = anthropic.Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY"))
|
||||
|
||||
@@ -243,30 +245,140 @@ def remove_from_memory(item):
|
||||
save_memory("\n".join(filtered))
|
||||
return len(lines) - len(filtered)
|
||||
|
||||
def retrieve_context(query, n_results=8):
|
||||
"""Pure semantic retrieval over pgvector. Top-N by cosine similarity, threshold 0.3.
|
||||
No CV pinning, no keyword routing — see architecture doc substrate-dependency section.
|
||||
Substrate-level workarounds (entity-keyed routing, hybrid retrieval) live at the
|
||||
Graphiti layer, not as wrapper logic above pgvector."""
|
||||
HYBRID_CANDIDATES = 30
|
||||
RRF_K = 60
|
||||
FINAL_LIMIT = 8
|
||||
|
||||
_TSQUERY_SANITIZE_RE = re.compile(r"[^\w\s\"'-]")
|
||||
|
||||
CONVERSATION_TYPES = ["chatgpt_conversation", "claude_conversation", "aaronai_conversation"]
|
||||
DOCUMENT_TYPES = ["document"]
|
||||
MEMORY_TYPES = ["claude_memory"]
|
||||
|
||||
_CONVO_SIGNALS = (
|
||||
"what did i tell", "what did we discuss", "what did we talk",
|
||||
"in our conversation", "you mentioned", "we talked about",
|
||||
"earlier you said", "earlier i said", "did i tell you",
|
||||
"did i say", "what did chatgpt", "what did claude",
|
||||
)
|
||||
_DOC_SIGNALS = (
|
||||
"write me a bio", "draft a bio", "my bio", "my cv", "my resume",
|
||||
"my professional", "my work history", "my exhibitions",
|
||||
"my publications", "my syllabi", "my courses", "my teaching",
|
||||
"my philosophy", "about my career", "draft a cover letter",
|
||||
"draft my", "write a bio", "professional bio",
|
||||
)
|
||||
|
||||
|
||||
def _websearch_query(text: str) -> str:
|
||||
"""Strip characters websearch_to_tsquery doesn't handle cleanly. Quoted
|
||||
phrases and 'or' are preserved by the function itself."""
|
||||
return _TSQUERY_SANITIZE_RE.sub(" ", text).strip()
|
||||
|
||||
|
||||
def classify_retrieval_intent(query: str):
|
||||
"""Return a list of `type` values to filter retrieval on, or None for all types.
|
||||
|
||||
Implementation is a low-effort keyword classifier — explicitly tunable and
|
||||
swappable. For more nuanced routing, replace this with an LLM classifier call
|
||||
that returns the same shape: a list of valid type strings or None.
|
||||
|
||||
Precedence: conversation signals win over document signals — a question like
|
||||
"what did I tell you about my CV" is asking about the conversation, not the CV.
|
||||
"""
|
||||
q = query.lower()
|
||||
if any(s in q for s in _CONVO_SIGNALS):
|
||||
return CONVERSATION_TYPES
|
||||
if any(s in q for s in _DOC_SIGNALS):
|
||||
return DOCUMENT_TYPES
|
||||
return None
|
||||
|
||||
|
||||
def _rerank(query: str, candidates: list[tuple]) -> list[tuple]:
|
||||
"""Cross-encoder rerank. Candidates are (id, document, source) tuples.
|
||||
Returns the same tuples reordered by reranker score (highest first)."""
|
||||
if not candidates:
|
||||
return []
|
||||
pairs = [(query, row[1]) for row in candidates]
|
||||
scores = reranker.predict(pairs)
|
||||
return [row for row, _ in sorted(zip(candidates, scores),
|
||||
key=lambda x: x[1], reverse=True)]
|
||||
|
||||
|
||||
def retrieve_context(query, n_results=FINAL_LIMIT, type_filter=None):
|
||||
"""Hybrid retrieval (dense + lexical, RRF fused) followed by cross-encoder rerank.
|
||||
|
||||
- Dense (pgvector) handles paraphrase / semantic similarity.
|
||||
- Lexical (tsvector) catches rare named tokens (FWN3D, Sono-Tek, course codes)
|
||||
the embedding model has no signal for.
|
||||
- RRF combines the two rankings without calibrating score scales.
|
||||
- Cross-encoder rerank scores each (query, chunk) pair jointly, bridging
|
||||
semantic gaps that bi-encoders can't (e.g., "write me a bio" -> CV chunk).
|
||||
|
||||
type_filter: optional list of `type` values to restrict the candidate pool to.
|
||||
If None, retrieves from all types. Use classify_retrieval_intent() to derive."""
|
||||
query_embedding = embedder.encode([query]).tolist()[0]
|
||||
ts_query = _websearch_query(query)
|
||||
|
||||
context_pieces = []
|
||||
sources = []
|
||||
|
||||
where_sql = ""
|
||||
type_param = ()
|
||||
if type_filter:
|
||||
where_sql = "WHERE type = ANY(%s)"
|
||||
type_param = (list(type_filter),)
|
||||
|
||||
try:
|
||||
pg = get_pg()
|
||||
cur = pg.cursor()
|
||||
cur.execute("""
|
||||
SELECT document, source, 1 - (embedding <=> %s::vector) as similarity
|
||||
|
||||
cur.execute(f"""
|
||||
SELECT id, document, source
|
||||
FROM embeddings
|
||||
{where_sql}
|
||||
ORDER BY embedding <=> %s::vector
|
||||
LIMIT %s
|
||||
""", (query_embedding, query_embedding, n_results))
|
||||
for doc, source, similarity in cur.fetchall():
|
||||
if similarity > 0.3:
|
||||
context_pieces.append(doc)
|
||||
sources.append(source or "unknown")
|
||||
""", (*type_param, query_embedding, HYBRID_CANDIDATES))
|
||||
dense_hits = cur.fetchall()
|
||||
|
||||
lexical_hits = []
|
||||
if ts_query:
|
||||
lex_where = "to_tsvector('english', document) @@ websearch_to_tsquery('english', %s)"
|
||||
full_where = (f"WHERE {lex_where} AND type = ANY(%s)"
|
||||
if type_filter else f"WHERE {lex_where}")
|
||||
lex_params = ((ts_query, list(type_filter)) if type_filter else (ts_query,))
|
||||
cur.execute(f"""
|
||||
SELECT id, document, source
|
||||
FROM embeddings
|
||||
{full_where}
|
||||
ORDER BY ts_rank(to_tsvector('english', document),
|
||||
websearch_to_tsquery('english', %s)) DESC
|
||||
LIMIT %s
|
||||
""", (*lex_params, ts_query, HYBRID_CANDIDATES))
|
||||
lexical_hits = cur.fetchall()
|
||||
|
||||
pg.close()
|
||||
|
||||
scores = {}
|
||||
rows_by_id = {}
|
||||
for rank, row in enumerate(dense_hits):
|
||||
scores[row[0]] = scores.get(row[0], 0) + 1.0 / (RRF_K + rank + 1)
|
||||
rows_by_id[row[0]] = row
|
||||
for rank, row in enumerate(lexical_hits):
|
||||
scores[row[0]] = scores.get(row[0], 0) + 1.0 / (RRF_K + rank + 1)
|
||||
rows_by_id[row[0]] = row
|
||||
|
||||
rrf_ranked = sorted(scores.items(), key=lambda kv: kv[1], reverse=True)
|
||||
candidates = [rows_by_id[doc_id] for doc_id, _ in rrf_ranked]
|
||||
|
||||
for _id, doc, source in _rerank(query, candidates)[:n_results]:
|
||||
context_pieces.append(doc)
|
||||
sources.append(source or "unknown")
|
||||
|
||||
except Exception as e:
|
||||
print(f"pgvector retrieval error: {e}")
|
||||
print(f"hybrid retrieval error: {e}")
|
||||
|
||||
return context_pieces, sources
|
||||
|
||||
def get_conversation_history(conversation_id, limit=20):
|
||||
@@ -306,7 +418,8 @@ def create_conversation(title="New conversation"):
|
||||
|
||||
def chat(user_message, conversation_id, settings, client_time=None):
|
||||
memory = load_memory()
|
||||
context_pieces, sources = retrieve_context(user_message)
|
||||
type_filter = classify_retrieval_intent(user_message)
|
||||
context_pieces, sources = retrieve_context(user_message, type_filter=type_filter)
|
||||
history = get_conversation_history(conversation_id)
|
||||
|
||||
context_parts = []
|
||||
|
||||
Reference in New Issue
Block a user