diff --git a/scripts/api.py b/scripts/api.py index b1bbf33..3fac813 100644 --- a/scripts/api.py +++ b/scripts/api.py @@ -277,25 +277,32 @@ def _websearch_query(text: str) -> str: def classify_retrieval_intent(query: str): - """Return a list of `type` values to filter retrieval on, or None for all types. + """Return (type_filter, folder_exclude_prefixes). Either may be None. + + type_filter restricts the candidate pool by `type`; folder_exclude_prefixes + excludes any chunk whose metadata.folder matches a LIKE 'prefix%' pattern. Implementation is a low-effort keyword classifier — explicitly tunable and - swappable. For more nuanced routing, replace this with an LLM classifier call - that returns the same shape: a list of valid type strings or None. + swappable. For nuanced routing, replace with an LLM classifier returning + the same shape. - Precedence: conversation signals win over document signals — a question like - "what did I tell you about my CV" is asking about the conversation, not the CV. - """ + Precedence: conversation signals win over document signals — "what did I + tell you about my CV" is asking about the conversation, not the CV. + + For biographical/document intent, also exclude the reference library + (Library/Foundations/* — philosophy and cognition books), which is + categorically different from personal artifacts but lives in the same + `type='document'` bucket.""" q = query.lower() if any(s in q for s in _CONVO_SIGNALS): - return CONVERSATION_TYPES + return (CONVERSATION_TYPES, None) if any(s in q for s in _DOC_SIGNALS): - return DOCUMENT_TYPES - return None + return (DOCUMENT_TYPES, ["Library/"]) + return (None, None) def _rerank(query: str, candidates: list[tuple]) -> list[tuple]: - """Cross-encoder rerank. Candidates are (id, document, source) tuples. + """Cross-encoder rerank. Candidates are (id, document, source, folder) tuples. Returns the same tuples reordered by reranker score (highest first).""" if not candidates: return [] @@ -305,7 +312,25 @@ def _rerank(query: str, candidates: list[tuple]) -> list[tuple]: key=lambda x: x[1], reverse=True)] -def retrieve_context(query, n_results=FINAL_LIMIT, type_filter=None): +def _format_source(source: str, folder: str) -> str: + """Surface folder context to the LLM so it can disambiguate same-named files + (e.g., 21 different CV.docx files across job-application folders).""" + source = source or "unknown" + if folder and folder not in ("", "."): + return f"{folder}/{source}" + return source + + +def _dedup_key(doc: str) -> str: + """Collapse near-duplicates by content. Files copied to multiple folders + produce byte-identical chunks; this catches those without affecting + legitimately-different chunks of the same source (e.g., separate sections + of a conversation).""" + return hashlib.md5(doc[:300].lower().encode("utf-8", "ignore")).hexdigest() + + +def retrieve_context(query, n_results=FINAL_LIMIT, + type_filter=None, folder_exclude_prefixes=None): """Hybrid retrieval (dense + lexical, RRF fused) followed by cross-encoder rerank. - Dense (pgvector) handles paraphrase / semantic similarity. @@ -314,48 +339,61 @@ def retrieve_context(query, n_results=FINAL_LIMIT, type_filter=None): - RRF combines the two rankings without calibrating score scales. - Cross-encoder rerank scores each (query, chunk) pair jointly, bridging semantic gaps that bi-encoders can't (e.g., "write me a bio" -> CV chunk). + - Near-duplicate collapse on output so top-N slots aren't burned by + multi-folder copies of the same file. type_filter: optional list of `type` values to restrict the candidate pool to. - If None, retrieves from all types. Use classify_retrieval_intent() to derive.""" + folder_exclude_prefixes: optional list of folder LIKE prefixes to exclude. + Both default to None (no restriction). Use classify_retrieval_intent() to derive.""" query_embedding = embedder.encode([query]).tolist()[0] ts_query = _websearch_query(query) context_pieces = [] sources = [] - where_sql = "" - type_param = () + where_clauses = [] + extra_params = [] if type_filter: - where_sql = "WHERE type = ANY(%s)" - type_param = (list(type_filter),) + where_clauses.append("type = ANY(%s)") + extra_params.append(list(type_filter)) + for prefix in (folder_exclude_prefixes or []): + where_clauses.append("(metadata->>'folder' IS NULL OR metadata->>'folder' NOT LIKE %s)") + extra_params.append(prefix + "%") + + common_where = ("WHERE " + " AND ".join(where_clauses)) if where_clauses else "" try: pg = get_pg() cur = pg.cursor() + # pgvector 0.6 HNSW doesn't iterate past its initial candidate list when + # a restrictive WHERE filter is present — so a filter that excludes the + # top-N nearest leaves nothing. Bumping ef_search forces the index to + # explore more graph nodes. Cheap when unfiltered; load-bearing when filtered. + if where_clauses: + cur.execute("SET LOCAL hnsw.ef_search = 500") + cur.execute(f""" - SELECT id, document, source + SELECT id, document, source, metadata->>'folder' AS folder FROM embeddings - {where_sql} + {common_where} ORDER BY embedding <=> %s::vector LIMIT %s - """, (*type_param, query_embedding, HYBRID_CANDIDATES)) + """, (*extra_params, query_embedding, HYBRID_CANDIDATES)) dense_hits = cur.fetchall() lexical_hits = [] if ts_query: - lex_where = "to_tsvector('english', document) @@ websearch_to_tsquery('english', %s)" - full_where = (f"WHERE {lex_where} AND type = ANY(%s)" - if type_filter else f"WHERE {lex_where}") - lex_params = ((ts_query, list(type_filter)) if type_filter else (ts_query,)) + lex_match = "to_tsvector('english', document) @@ websearch_to_tsquery('english', %s)" + lex_where = ("WHERE " + " AND ".join([lex_match] + where_clauses)) cur.execute(f""" - SELECT id, document, source + SELECT id, document, source, metadata->>'folder' AS folder FROM embeddings - {full_where} + {lex_where} ORDER BY ts_rank(to_tsvector('english', document), websearch_to_tsquery('english', %s)) DESC LIMIT %s - """, (*lex_params, ts_query, HYBRID_CANDIDATES)) + """, (ts_query, *extra_params, ts_query, HYBRID_CANDIDATES)) lexical_hits = cur.fetchall() pg.close() @@ -372,9 +410,16 @@ def retrieve_context(query, n_results=FINAL_LIMIT, type_filter=None): rrf_ranked = sorted(scores.items(), key=lambda kv: kv[1], reverse=True) candidates = [rows_by_id[doc_id] for doc_id, _ in rrf_ranked] - for _id, doc, source in _rerank(query, candidates)[:n_results]: + seen = set() + for _id, doc, source, folder in _rerank(query, candidates): + key = _dedup_key(doc) + if key in seen: + continue + seen.add(key) context_pieces.append(doc) - sources.append(source or "unknown") + sources.append(_format_source(source, folder)) + if len(context_pieces) >= n_results: + break except Exception as e: print(f"hybrid retrieval error: {e}") @@ -418,8 +463,11 @@ def create_conversation(title="New conversation"): def chat(user_message, conversation_id, settings, client_time=None): memory = load_memory() - type_filter = classify_retrieval_intent(user_message) - context_pieces, sources = retrieve_context(user_message, type_filter=type_filter) + type_filter, folder_excludes = classify_retrieval_intent(user_message) + context_pieces, sources = retrieve_context( + user_message, type_filter=type_filter, + folder_exclude_prefixes=folder_excludes, + ) history = get_conversation_history(conversation_id) context_parts = [] diff --git a/scripts/test_retrieval.py b/scripts/test_retrieval.py index 516a3b1..88ffd75 100644 --- a/scripts/test_retrieval.py +++ b/scripts/test_retrieval.py @@ -50,9 +50,11 @@ QUERIES = [ ] for q in QUERIES: - intent = classify_retrieval_intent(q) - pieces, sources = retrieve_context(q, type_filter=intent) + type_filter, folder_excludes = classify_retrieval_intent(q) + pieces, sources = retrieve_context( + q, type_filter=type_filter, folder_exclude_prefixes=folder_excludes, + ) print(f"\n=== {q!r} ===") - print(f" intent: {intent}") + print(f" type_filter: {type_filter} folder_excludes: {folder_excludes}") for i, src in enumerate(sources, 1): print(f" {i}. {src}")