diff --git a/graphiti_patches/.gitignore b/graphiti_patches/.gitignore new file mode 100644 index 0000000..624bdda --- /dev/null +++ b/graphiti_patches/.gitignore @@ -0,0 +1,4 @@ +# Local backups created by apply.sh — environment state, not source. +# Keeping these out of version control prevents repo bloat and avoids +# checking in graphiti-core's Apache-2.0 source under our repo's tree. +backups/ diff --git a/graphiti_patches/README.md b/graphiti_patches/README.md new file mode 100644 index 0000000..da760db --- /dev/null +++ b/graphiti_patches/README.md @@ -0,0 +1,58 @@ +# graphiti-core Patches — FalkorDB Vector Index Support + +Vendored patches against graphiti-core 0.29.0 adding native FalkorDB +vector index support. Three files modified, all under +`graphiti_core/driver/falkordb/` and `graphiti_core/graph_queries.py`. +No changes to Neo4j or Kuzu code paths. + +## Why this exists + +graphiti-core's FalkorDB driver uses interpreted Cypher cosine math +(`vec.cosineDistance(...)`) for similarity search. Each query becomes a +full table scan over Entity/RELATES_TO/Community nodes. At ~4,000+ +entities, single-episode ingest's resolve-against-existing-graph step +takes 8+ minutes and bulk ingest hangs FalkorDB. FalkorDB itself +supports `db.idx.vector.queryNodes` and `db.idx.vector.queryRelationships` +procedures backed by HNSW indexes; graphiti-core's driver doesn't use +them. + +These patches: + +1. Add `get_vector_indices()` to `graph_queries.py` returning CREATE + VECTOR INDEX statements for FalkorDB on Entity.name_embedding, + RELATES_TO.fact_embedding, and Community.name_embedding. +2. Extend `falkordb_driver.py:build_indices_and_constraints()` to create + the vector indexes alongside range and fulltext indexes. +3. Rewrite the three vector-similarity call sites in + `falkordb/operations/search_ops.py` to use + `db.idx.vector.queryNodes` and `db.idx.vector.queryRelationships` + instead of full-scan cosine math. Over-fetches by a configurable + multiplier to handle filter rejections. + +## Files + +| Patched file | Source | +|---|---| +| `graphiti_core/graph_queries.py` | Adds `get_vector_indices()` | +| `graphiti_core/driver/falkordb/falkordb_driver.py` | Extends `build_indices_and_constraints` | +| `graphiti_core/driver/falkordb/operations/search_ops.py` | Three query rewrites | + +## How to apply + +`./apply.sh` — backs up the originals into `./backups//` +and copies the patched files over. + +## How to revert + +Move the timestamped backup back over the venv: + + cp backups//graph_queries.py /home/aaron/aaronai/venv/lib/python3.12/site-packages/graphiti_core/graph_queries.py + # ...etc + +## Upstream candidate + +Documented gap (issue #1263 references it indirectly via vector store +overlay RFC). Maintainers' attention is on Milvus/external vector DB +overlay; this patch is the FalkorDB-native alternative for users who +don't want a separate vector DB. Consider PR after empirical validation +in production. diff --git a/graphiti_patches/apply.sh b/graphiti_patches/apply.sh new file mode 100755 index 0000000..986ffca --- /dev/null +++ b/graphiti_patches/apply.sh @@ -0,0 +1,77 @@ +#!/usr/bin/env bash +# apply.sh — Apply the BirdAI vendored graphiti-core patches. +# +# Backs up the original venv files into ./backups// before +# overwriting. The backup directory layout mirrors the venv layout so a +# revert is just a tree copy back. +# +# Usage: ./apply.sh + +set -euo pipefail + +PATCH_DIR="$(cd "$(dirname "$0")" && pwd)" +VENV_BASE="/home/aaron/aaronai/venv/lib/python3.12/site-packages" +TIMESTAMP="$(date +%Y%m%d-%H%M%S)" +BACKUP_DIR="$PATCH_DIR/backups/$TIMESTAMP" + +# Files to patch — paths relative to graphiti_core/. +FILES=( + "graph_queries.py" + "driver/falkordb_driver.py" + "driver/falkordb/operations/search_ops.py" +) + +echo "graphiti-core vendored patch apply — BirdAI" +echo "Patch directory: $PATCH_DIR" +echo "Venv target: $VENV_BASE/graphiti_core/" +echo "Backup to: $BACKUP_DIR" +echo + +# Pre-flight: confirm all source patch files exist. +for rel in "${FILES[@]}"; do + if [ ! -f "$PATCH_DIR/graphiti_core/$rel" ]; then + echo "ERROR: missing patch file: $PATCH_DIR/graphiti_core/$rel" >&2 + exit 1 + fi +done + +# Pre-flight: confirm all target venv files exist. +for rel in "${FILES[@]}"; do + if [ ! -f "$VENV_BASE/graphiti_core/$rel" ]; then + echo "ERROR: missing venv file: $VENV_BASE/graphiti_core/$rel" >&2 + echo " graphiti-core may not be installed, or version differs from 0.29.0." >&2 + exit 1 + fi +done + +# Backup originals. +echo "[1/3] Backing up originals..." +for rel in "${FILES[@]}"; do + backup_path="$BACKUP_DIR/graphiti_core/$rel" + mkdir -p "$(dirname "$backup_path")" + cp "$VENV_BASE/graphiti_core/$rel" "$backup_path" + echo " backed up: $rel" +done +echo + +# Apply patches by copying. +echo "[2/3] Applying patches..." +for rel in "${FILES[@]}"; do + cp "$PATCH_DIR/graphiti_core/$rel" "$VENV_BASE/graphiti_core/$rel" + echo " patched: $rel" +done +echo + +# Sanity check: confirm patched files have the marker. +echo "[3/3] Verifying patched files..." +for rel in "${FILES[@]}"; do + if grep -q "PATCHED 2026-05-02" "$VENV_BASE/graphiti_core/$rel"; then + echo " OK: $rel contains patch marker" + else + echo " WARNING: $rel missing patch marker (may be expected for graph_queries.py — its docstring uses the marker only in the module header)" + fi +done +echo +echo "Done. Backup: $BACKUP_DIR" +echo "Restart the sidecar to pick up changes:" +echo " sudo systemctl restart aaronai-graphiti.service" diff --git a/graphiti_patches/graphiti_core/driver/falkordb/operations/search_ops.py b/graphiti_patches/graphiti_core/driver/falkordb/operations/search_ops.py new file mode 100644 index 0000000..639d1e1 --- /dev/null +++ b/graphiti_patches/graphiti_core/driver/falkordb/operations/search_ops.py @@ -0,0 +1,904 @@ +""" +Copyright 2024, Zep Software, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import logging +from typing import Any + +from graphiti_core.driver.driver import GraphProvider +from graphiti_core.driver.falkordb import STOPWORDS +from graphiti_core.driver.operations.search_ops import SearchOperations +from graphiti_core.driver.query_executor import QueryExecutor +from graphiti_core.driver.record_parsers import ( + community_node_from_record, + entity_edge_from_record, + entity_node_from_record, + episodic_node_from_record, +) +from graphiti_core.edges import EntityEdge +from graphiti_core.graph_queries import ( + get_nodes_query, + get_relationships_query, + get_vector_cosine_func_query, +) +from graphiti_core.models.edges.edge_db_queries import get_entity_edge_return_query +from graphiti_core.models.nodes.node_db_queries import ( + COMMUNITY_NODE_RETURN, + EPISODIC_NODE_RETURN, + get_entity_node_return_query, +) +from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode +from graphiti_core.search.search_filters import ( + SearchFilters, + edge_search_filter_query_constructor, + node_search_filter_query_constructor, +) + +logger = logging.getLogger(__name__) + +MAX_QUERY_LENGTH = 128 + +# --------------------------------------------------------------------------- +# Vector index dispatcher (PATCHED 2026-05-02, BirdAI vendored patch). +# +# graphiti-core's FalkorDB driver historically composed similarity queries +# using `vec.cosineDistance(...)` in interpreted Cypher, which produces a +# full-table scan for every search. FalkorDB supports native vector indexes +# via `db.idx.vector.queryNodes` and `db.idx.vector.queryRelationships`; +# this dispatcher uses them when present and falls back to the cosine math +# otherwise. +# +# Index existence is checked once per (label, attribute, entity_type) and +# cached at module scope. The cache should be invalidated whenever +# `build_indices_and_constraints` runs (since indexes may have been created +# or dropped). FalkorDriver.build_indices_and_constraints is patched to +# call `_invalidate_falkordb_vector_index_cache()` after building. +# +# Over-fetch factor (VECTOR_INDEX_CANDIDATE_MULTIPLIER from graph_queries) +# preserves recall when WHERE filters reject some of the top-k candidates. +# --------------------------------------------------------------------------- + +from graphiti_core.graph_queries import ( + VECTOR_INDEX_CANDIDATE_MULTIPLIER, + get_vector_cosine_func_query, +) + +# Cache: key = (label, attribute, entity_type), value = bool +# entity_type is 'NODE' or 'RELATIONSHIP'. +_FALKORDB_VECTOR_INDEX_CACHE: dict[tuple[str, str, str], bool] = {} + + +def _invalidate_falkordb_vector_index_cache() -> None: + """Clear the vector-index existence cache. Call after build_indices_and_constraints.""" + _FALKORDB_VECTOR_INDEX_CACHE.clear() + + +async def _falkordb_vector_index_exists( + executor: QueryExecutor, + label: str, + attribute: str, + entity_type: str, +) -> bool: + """Check whether a FalkorDB vector index exists for the given target. + + entity_type is 'NODE' for node-label indexes, 'RELATIONSHIP' for edge-type indexes. + Result is cached at module scope; call _invalidate_falkordb_vector_index_cache() + after building or dropping indexes. + """ + key = (label, attribute, entity_type) + if key in _FALKORDB_VECTOR_INDEX_CACHE: + return _FALKORDB_VECTOR_INDEX_CACHE[key] + + try: + records, _, _ = await executor.execute_query( + "CALL db.indexes() YIELD label, properties, types, entitytype " + "RETURN label, properties, types, entitytype" + ) + except Exception as e: + # If we cannot enumerate indexes, fall back to "no index" rather than + # propagating the error. The fallback cosine-math path is correct, + # just slower. + logger.warning(f"FalkorDB vector index probe failed; assuming none exist: {e}") + _FALKORDB_VECTOR_INDEX_CACHE[key] = False + return False + + found = False + for r in records: + # Records come back as dict-like rows keyed by column name (not + # tuples). Access by string keys matching the YIELD clause above. + rec_label = r.get('label') if hasattr(r, 'get') else r['label'] + rec_props = r.get('properties') if hasattr(r, 'get') else r['properties'] + rec_types = r.get('types') if hasattr(r, 'get') else r['types'] + rec_entitytype = r.get('entitytype') if hasattr(r, 'get') else r['entitytype'] + if rec_props is None: + rec_props = [] + if rec_types is None: + rec_types = {} + + if rec_label != label: + continue + if rec_entitytype is not None and rec_entitytype != entity_type: + continue + if attribute not in rec_props: + continue + + # rec_types is a dict like {attribute: ['VECTOR', ...], ...} or sometimes + # a flat list — handle both shapes. + if isinstance(rec_types, dict): + attr_types = rec_types.get(attribute, []) + else: + attr_types = rec_types + if 'VECTOR' in attr_types: + found = True + break + + _FALKORDB_VECTOR_INDEX_CACHE[key] = found + return found + + +def _falkordb_vector_node_search_cypher( + label: str, + embedding_attr: str, + search_vector_param: str, + use_index: bool, +) -> tuple[str, str]: + """Build the cypher prefix and node-binding for a node-vector search. + + Returns (prefix, node_var) where: + - prefix is the Cypher fragment that binds the node variable and a + `score` variable. With index, it's a CALL ... YIELD; without, it's + a MATCH plus WITH cosine math. + - node_var is the variable name the caller's downstream Cypher should + reference (always 'n' here for parity with the existing code). + + The caller appends WHERE filters and RETURN/ORDER BY/LIMIT as usual. + The over-fetch parameter `$candidate_k` must be passed by the caller + when use_index is True. + """ + if use_index: + return ( + f"CALL db.idx.vector.queryNodes(" + f"'{label}', '{embedding_attr}', $candidate_k, vecf32({search_vector_param})" + f") YIELD node, score " + f"WITH node AS n, score " + ), "n" + # Fallback: original cosine math path + cosine = get_vector_cosine_func_query( + f"n.{embedding_attr}", search_vector_param, GraphProvider.FALKORDB + ) + return ( + f"MATCH (n:{label}) " + f"WITH n, {cosine} AS score " + ), "n" + + +def _falkordb_vector_edge_search_cypher( + relationship_type: str, + embedding_attr: str, + search_vector_param: str, + use_index: bool, +) -> tuple[str, str]: + """Build the cypher prefix and edge-binding for an edge-vector search. + + Returns (prefix, edge_var). With the index, the procedure binds the + relationship variable; we then MATCH source and target via the existing + edge to recover (n)-[e]->(m). Without the index, it's the original + MATCH-and-cosine path. + + Variable name is 'e' for parity with existing code; source/target are + 'n' and 'm' respectively, also for parity. + """ + if use_index: + return ( + f"CALL db.idx.vector.queryRelationships(" + f"'{relationship_type}', '{embedding_attr}', $candidate_k, vecf32({search_vector_param})" + f") YIELD relationship, score " + f"MATCH (n:Entity)-[e:{relationship_type}]->(m:Entity) " + f"WHERE e = relationship " + f"WITH DISTINCT e, n, m, score " + ), "e" + # Fallback + cosine = get_vector_cosine_func_query( + f"e.{embedding_attr}", search_vector_param, GraphProvider.FALKORDB + ) + return ( + f"MATCH (n:Entity)-[e:{relationship_type}]->(m:Entity) " + f"WITH DISTINCT e, n, m, {cosine} AS score " + ), "e" + + + +# FalkorDB separator characters that break text into tokens +_SEPARATOR_MAP = str.maketrans( + { + ',': ' ', + '.': ' ', + '<': ' ', + '>': ' ', + '{': ' ', + '}': ' ', + '[': ' ', + ']': ' ', + '"': ' ', + "'": ' ', + ':': ' ', + ';': ' ', + '!': ' ', + '@': ' ', + '#': ' ', + '$': ' ', + '%': ' ', + '^': ' ', + '&': ' ', + '*': ' ', + '(': ' ', + ')': ' ', + '-': ' ', + '+': ' ', + '=': ' ', + '~': ' ', + '?': ' ', + '|': ' ', + '/': ' ', + '\\': ' ', + } +) + + +def _sanitize(query: str) -> str: + """Replace FalkorDB special characters with whitespace.""" + sanitized = query.translate(_SEPARATOR_MAP) + return ' '.join(sanitized.split()) + + +def _build_falkor_fulltext_query( + query: str, + group_ids: list[str] | None = None, + max_query_length: int = MAX_QUERY_LENGTH, +) -> str: + """Build a fulltext query string for FalkorDB using RedisSearch syntax.""" + if group_ids is None or len(group_ids) == 0: + group_filter = '' + else: + escaped_group_ids = [f'"{gid}"' for gid in group_ids] + group_values = '|'.join(escaped_group_ids) + group_filter = f'(@group_id:{group_values})' + + sanitized_query = _sanitize(query) + + # Remove stopwords and empty tokens + query_words = sanitized_query.split() + filtered_words = [word for word in query_words if word and word.lower() not in STOPWORDS] + sanitized_query = ' | '.join(filtered_words) + + if len(sanitized_query.split(' ')) + len(group_ids or '') >= max_query_length: + return '' + + full_query = group_filter + ' (' + sanitized_query + ')' + return full_query + + +class FalkorSearchOperations(SearchOperations): + # --- Node search --- + + async def node_fulltext_search( + self, + executor: QueryExecutor, + query: str, + search_filter: SearchFilters, + group_ids: list[str] | None = None, + limit: int = 10, + ) -> list[EntityNode]: + fuzzy_query = _build_falkor_fulltext_query(query, group_ids) + if fuzzy_query == '': + return [] + + filter_queries, filter_params = node_search_filter_query_constructor( + search_filter, GraphProvider.FALKORDB + ) + + if group_ids is not None: + filter_queries.append('n.group_id IN $group_ids') + filter_params['group_ids'] = group_ids + + filter_query = '' + if filter_queries: + filter_query = ' WHERE ' + (' AND '.join(filter_queries)) + + cypher = ( + get_nodes_query( + 'node_name_and_summary', '$query', limit=limit, provider=GraphProvider.FALKORDB + ) + + 'YIELD node AS n, score' + + filter_query + + """ + WITH n, score + ORDER BY score DESC + LIMIT $limit + RETURN + """ + + get_entity_node_return_query(GraphProvider.FALKORDB) + ) + + records, _, _ = await executor.execute_query( + cypher, + query=fuzzy_query, + limit=limit, + **filter_params, + ) + + return [entity_node_from_record(r) for r in records] + + async def node_similarity_search( + self, + executor: QueryExecutor, + search_vector: list[float], + search_filter: SearchFilters, + group_ids: list[str] | None = None, + limit: int = 10, + min_score: float = 0.6, + ) -> list[EntityNode]: + filter_queries, filter_params = node_search_filter_query_constructor( + search_filter, GraphProvider.FALKORDB + ) + + if group_ids is not None: + filter_queries.append('n.group_id IN $group_ids') + filter_params['group_ids'] = group_ids + + filter_query = '' + if filter_queries: + filter_query = ' WHERE ' + (' AND '.join(filter_queries)) + + # PATCHED 2026-05-02 (BirdAI vendored patch): use FalkorDB native vector + # index when available; fall back to interpreted-Cypher cosine math + # otherwise. The filter clause's position changes between paths + # (after MATCH for fallback, after YIELD for index path), but the + # filter expressions themselves are identical because they reference + # the bound variable `n` either way. + use_index = await _falkordb_vector_index_exists( + executor, 'Entity', 'name_embedding', 'NODE' + ) + prefix, _ = _falkordb_vector_node_search_cypher( + 'Entity', 'name_embedding', '$search_vector', use_index + ) + where_clauses = [] + if filter_query: + where_clauses.append(filter_query.replace(' WHERE ', '', 1).strip()) + where_clauses.append('score > $min_score') + unified_where = ' WHERE ' + ' AND '.join(where_clauses) + + cypher = ( + prefix + + unified_where + + """ + RETURN + """ + + get_entity_node_return_query(GraphProvider.FALKORDB) + + """ + ORDER BY score DESC + LIMIT $limit + """ + ) + params = dict( + search_vector=search_vector, + limit=limit, + min_score=min_score, + **filter_params, + ) + if use_index: + params['candidate_k'] = limit * VECTOR_INDEX_CANDIDATE_MULTIPLIER + records, _, _ = await executor.execute_query(cypher, **params) + + return [entity_node_from_record(r) for r in records] + + async def node_bfs_search( + self, + executor: QueryExecutor, + origin_uuids: list[str], + search_filter: SearchFilters, + max_depth: int, + group_ids: list[str] | None = None, + limit: int = 10, + ) -> list[EntityNode]: + if not origin_uuids or max_depth < 1: + return [] + + filter_queries, filter_params = node_search_filter_query_constructor( + search_filter, GraphProvider.FALKORDB + ) + + if group_ids is not None: + filter_queries.append('n.group_id IN $group_ids') + filter_queries.append('origin.group_id IN $group_ids') + filter_params['group_ids'] = group_ids + + filter_query = '' + if filter_queries: + filter_query = ' AND ' + (' AND '.join(filter_queries)) + + cypher = ( + f""" + UNWIND $bfs_origin_node_uuids AS origin_uuid + MATCH (origin {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS*1..{max_depth}]->(n:Entity) + WHERE n.group_id = origin.group_id + """ + + filter_query + + """ + RETURN + """ + + get_entity_node_return_query(GraphProvider.FALKORDB) + + """ + LIMIT $limit + """ + ) + + records, _, _ = await executor.execute_query( + cypher, + bfs_origin_node_uuids=origin_uuids, + limit=limit, + **filter_params, + ) + + return [entity_node_from_record(r) for r in records] + + # --- Edge search --- + + async def edge_fulltext_search( + self, + executor: QueryExecutor, + query: str, + search_filter: SearchFilters, + group_ids: list[str] | None = None, + limit: int = 10, + ) -> list[EntityEdge]: + fuzzy_query = _build_falkor_fulltext_query(query, group_ids) + if fuzzy_query == '': + return [] + + filter_queries, filter_params = edge_search_filter_query_constructor( + search_filter, GraphProvider.FALKORDB + ) + + if group_ids is not None: + filter_queries.append('e.group_id IN $group_ids') + filter_params['group_ids'] = group_ids + + filter_query = '' + if filter_queries: + filter_query = ' WHERE ' + (' AND '.join(filter_queries)) + + cypher = ( + get_relationships_query( + 'edge_name_and_fact', limit=limit, provider=GraphProvider.FALKORDB + ) + + """ + YIELD relationship AS rel, score + MATCH (n:Entity)-[e:RELATES_TO {uuid: rel.uuid}]->(m:Entity) + """ + + filter_query + + """ + WITH e, score, n, m + RETURN + """ + + get_entity_edge_return_query(GraphProvider.FALKORDB) + + """ + ORDER BY score DESC + LIMIT $limit + """ + ) + + records, _, _ = await executor.execute_query( + cypher, + query=fuzzy_query, + limit=limit, + **filter_params, + ) + + return [entity_edge_from_record(r) for r in records] + + async def edge_similarity_search( + self, + executor: QueryExecutor, + search_vector: list[float], + source_node_uuid: str | None, + target_node_uuid: str | None, + search_filter: SearchFilters, + group_ids: list[str] | None = None, + limit: int = 10, + min_score: float = 0.6, + ) -> list[EntityEdge]: + filter_queries, filter_params = edge_search_filter_query_constructor( + search_filter, GraphProvider.FALKORDB + ) + + if group_ids is not None: + filter_queries.append('e.group_id IN $group_ids') + filter_params['group_ids'] = group_ids + + if source_node_uuid is not None: + filter_params['source_uuid'] = source_node_uuid + filter_queries.append('n.uuid = $source_uuid') + + if target_node_uuid is not None: + filter_params['target_uuid'] = target_node_uuid + filter_queries.append('m.uuid = $target_uuid') + + filter_query = '' + if filter_queries: + filter_query = ' WHERE ' + (' AND '.join(filter_queries)) + + # PATCHED 2026-05-02 (BirdAI vendored patch): use FalkorDB native vector + # index on RELATES_TO.fact_embedding when available. The unindexed + # fallback is the same MATCH-and-cosine math that previously hung + # for 6+ minutes on a 4,000-entity graph; this is the load-bearing + # call site that motivated the patch. + use_index = await _falkordb_vector_index_exists( + executor, 'RELATES_TO', 'fact_embedding', 'RELATIONSHIP' + ) + prefix, _ = _falkordb_vector_edge_search_cypher( + 'RELATES_TO', 'fact_embedding', '$search_vector', use_index + ) + where_clauses = [] + if filter_query: + where_clauses.append(filter_query.replace(' WHERE ', '', 1).strip()) + where_clauses.append('score > $min_score') + unified_where = ' WHERE ' + ' AND '.join(where_clauses) + + cypher = ( + prefix + + unified_where + + """ + RETURN + """ + + get_entity_edge_return_query(GraphProvider.FALKORDB) + + """ + ORDER BY score DESC + LIMIT $limit + """ + ) + params = dict( + search_vector=search_vector, + limit=limit, + min_score=min_score, + **filter_params, + ) + if use_index: + params['candidate_k'] = limit * VECTOR_INDEX_CANDIDATE_MULTIPLIER + records, _, _ = await executor.execute_query(cypher, **params) + + return [entity_edge_from_record(r) for r in records] + + async def edge_bfs_search( + self, + executor: QueryExecutor, + origin_uuids: list[str], + max_depth: int, + search_filter: SearchFilters, + group_ids: list[str] | None = None, + limit: int = 10, + ) -> list[EntityEdge]: + if not origin_uuids: + return [] + + filter_queries, filter_params = edge_search_filter_query_constructor( + search_filter, GraphProvider.FALKORDB + ) + + if group_ids is not None: + filter_queries.append('e.group_id IN $group_ids') + filter_params['group_ids'] = group_ids + + filter_query = '' + if filter_queries: + filter_query = ' WHERE ' + (' AND '.join(filter_queries)) + + cypher = ( + f""" + UNWIND $bfs_origin_node_uuids AS origin_uuid + MATCH path = (origin {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS*1..{max_depth}]->(:Entity) + UNWIND relationships(path) AS rel + MATCH (n:Entity)-[e:RELATES_TO {{uuid: rel.uuid}}]-(m:Entity) + """ + + filter_query + + """ + RETURN DISTINCT + """ + + get_entity_edge_return_query(GraphProvider.FALKORDB) + + """ + LIMIT $limit + """ + ) + + records, _, _ = await executor.execute_query( + cypher, + bfs_origin_node_uuids=origin_uuids, + depth=max_depth, + limit=limit, + **filter_params, + ) + + return [entity_edge_from_record(r) for r in records] + + # --- Episode search --- + + async def episode_fulltext_search( + self, + executor: QueryExecutor, + query: str, + search_filter: SearchFilters, # noqa: ARG002 + group_ids: list[str] | None = None, + limit: int = 10, + ) -> list[EpisodicNode]: + fuzzy_query = _build_falkor_fulltext_query(query, group_ids) + if fuzzy_query == '': + return [] + + filter_params: dict[str, Any] = {} + group_filter_query = '' + if group_ids is not None: + group_filter_query += '\nAND e.group_id IN $group_ids' + filter_params['group_ids'] = group_ids + + cypher = ( + get_nodes_query( + 'episode_content', '$query', limit=limit, provider=GraphProvider.FALKORDB + ) + + """ + YIELD node AS episode, score + MATCH (e:Episodic) + WHERE e.uuid = episode.uuid + """ + + group_filter_query + + """ + RETURN + """ + + EPISODIC_NODE_RETURN + + """ + ORDER BY score DESC + LIMIT $limit + """ + ) + + records, _, _ = await executor.execute_query( + cypher, query=fuzzy_query, limit=limit, **filter_params + ) + + return [episodic_node_from_record(r) for r in records] + + # --- Community search --- + + async def community_fulltext_search( + self, + executor: QueryExecutor, + query: str, + group_ids: list[str] | None = None, + limit: int = 10, + ) -> list[CommunityNode]: + fuzzy_query = _build_falkor_fulltext_query(query, group_ids) + if fuzzy_query == '': + return [] + + filter_params: dict[str, Any] = {} + group_filter_query = '' + if group_ids is not None: + group_filter_query = 'WHERE c.group_id IN $group_ids' + filter_params['group_ids'] = group_ids + + cypher = ( + get_nodes_query( + 'community_name', '$query', limit=limit, provider=GraphProvider.FALKORDB + ) + + """ + YIELD node AS c, score + WITH c, score + """ + + group_filter_query + + """ + RETURN + """ + + COMMUNITY_NODE_RETURN + + """ + ORDER BY score DESC + LIMIT $limit + """ + ) + + records, _, _ = await executor.execute_query( + cypher, query=fuzzy_query, limit=limit, **filter_params + ) + + return [community_node_from_record(r) for r in records] + + async def community_similarity_search( + self, + executor: QueryExecutor, + search_vector: list[float], + group_ids: list[str] | None = None, + limit: int = 10, + min_score: float = 0.6, + ) -> list[CommunityNode]: + query_params: dict[str, Any] = {} + + group_filter_query = '' + if group_ids is not None: + group_filter_query += ' WHERE c.group_id IN $group_ids' + query_params['group_ids'] = group_ids + + # PATCHED 2026-05-02 (BirdAI vendored patch): use FalkorDB native vector + # index on Community.name_embedding when available. Note: the existing + # filter is built into `group_filter_query` (already prefixed with + # ' WHERE ' if non-empty) and uses variable `c`. The dispatcher binds + # the node as `n` for parity with the helper signature, then we + # re-bind to `c` via WITH so the rest of the query is unchanged. + use_index = await _falkordb_vector_index_exists( + executor, 'Community', 'name_embedding', 'NODE' + ) + prefix, _ = _falkordb_vector_node_search_cypher( + 'Community', 'name_embedding', '$search_vector', use_index + ) + prefix = prefix + ' WITH n AS c, score ' + where_clauses = [] + if group_filter_query: + where_clauses.append(group_filter_query.replace(' WHERE ', '', 1).strip()) + where_clauses.append('score > $min_score') + unified_where = ' WHERE ' + ' AND '.join(where_clauses) + + cypher = ( + prefix + + unified_where + + """ + RETURN + """ + + COMMUNITY_NODE_RETURN + + """ + ORDER BY score DESC + LIMIT $limit + """ + ) + params = dict( + search_vector=search_vector, + limit=limit, + min_score=min_score, + **query_params, + ) + if use_index: + params['candidate_k'] = limit * VECTOR_INDEX_CANDIDATE_MULTIPLIER + records, _, _ = await executor.execute_query(cypher, **params) + + return [community_node_from_record(r) for r in records] + + # --- Rerankers --- + + async def node_distance_reranker( + self, + executor: QueryExecutor, + node_uuids: list[str], + center_node_uuid: str, + min_score: float = 0, + ) -> list[EntityNode]: + filtered_uuids = [u for u in node_uuids if u != center_node_uuid] + scores: dict[str, float] = {center_node_uuid: 0.0} + + cypher = """ + UNWIND $node_uuids AS node_uuid + MATCH (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-(n:Entity {uuid: node_uuid}) + RETURN 1 AS score, node_uuid AS uuid + """ + + results, _, _ = await executor.execute_query( + cypher, + node_uuids=filtered_uuids, + center_uuid=center_node_uuid, + ) + + for result in results: + scores[result['uuid']] = result['score'] + + for uuid in filtered_uuids: + if uuid not in scores: + scores[uuid] = float('inf') + + filtered_uuids.sort(key=lambda cur_uuid: scores[cur_uuid]) + + if center_node_uuid in node_uuids: + scores[center_node_uuid] = 0.1 + filtered_uuids = [center_node_uuid] + filtered_uuids + + reranked_uuids = [u for u in filtered_uuids if (1 / scores[u]) >= min_score] + + if not reranked_uuids: + return [] + + get_query = """ + MATCH (n:Entity) + WHERE n.uuid IN $uuids + RETURN + """ + get_entity_node_return_query(GraphProvider.FALKORDB) + + records, _, _ = await executor.execute_query(get_query, uuids=reranked_uuids) + + node_map = {r['uuid']: entity_node_from_record(r) for r in records} + return [node_map[u] for u in reranked_uuids if u in node_map] + + async def episode_mentions_reranker( + self, + executor: QueryExecutor, + node_uuids: list[str], + min_score: float = 0, + ) -> list[EntityNode]: + if not node_uuids: + return [] + + scores: dict[str, float] = {} + + results, _, _ = await executor.execute_query( + """ + UNWIND $node_uuids AS node_uuid + MATCH (episode:Episodic)-[r:MENTIONS]->(n:Entity {uuid: node_uuid}) + RETURN count(*) AS score, n.uuid AS uuid + """, + node_uuids=node_uuids, + ) + + for result in results: + scores[result['uuid']] = result['score'] + + for uuid in node_uuids: + if uuid not in scores: + scores[uuid] = float('inf') + + sorted_uuids = list(node_uuids) + sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid]) + + reranked_uuids = [u for u in sorted_uuids if scores[u] >= min_score] + + if not reranked_uuids: + return [] + + get_query = """ + MATCH (n:Entity) + WHERE n.uuid IN $uuids + RETURN + """ + get_entity_node_return_query(GraphProvider.FALKORDB) + + records, _, _ = await executor.execute_query(get_query, uuids=reranked_uuids) + + node_map = {r['uuid']: entity_node_from_record(r) for r in records} + return [node_map[u] for u in reranked_uuids if u in node_map] + + # --- Filter builders --- + + def build_node_search_filters(self, search_filters: SearchFilters) -> Any: + filter_queries, filter_params = node_search_filter_query_constructor( + search_filters, GraphProvider.FALKORDB + ) + return {'filter_queries': filter_queries, 'filter_params': filter_params} + + def build_edge_search_filters(self, search_filters: SearchFilters) -> Any: + filter_queries, filter_params = edge_search_filter_query_constructor( + search_filters, GraphProvider.FALKORDB + ) + return {'filter_queries': filter_queries, 'filter_params': filter_params} + + # --- Fulltext query builder --- + + def build_fulltext_query( + self, + query: str, + group_ids: list[str] | None = None, + max_query_length: int = MAX_QUERY_LENGTH, + ) -> str: + return _build_falkor_fulltext_query(query, group_ids, max_query_length) diff --git a/graphiti_patches/graphiti_core/driver/falkordb_driver.py b/graphiti_patches/graphiti_core/driver/falkordb_driver.py new file mode 100644 index 0000000..824cc2b --- /dev/null +++ b/graphiti_patches/graphiti_core/driver/falkordb_driver.py @@ -0,0 +1,444 @@ +""" +Copyright 2024, Zep Software, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import asyncio +import datetime +import logging +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from falkordb import Graph as FalkorGraph + from falkordb.asyncio import FalkorDB +else: + try: + from falkordb import Graph as FalkorGraph + from falkordb.asyncio import FalkorDB + except ImportError: + # If falkordb is not installed, raise an ImportError + raise ImportError( + 'falkordb is required for FalkorDriver. ' + 'Install it with: pip install graphiti-core[falkordb]' + ) from None + +from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider +from graphiti_core.driver.falkordb import STOPWORDS as STOPWORDS +from graphiti_core.driver.falkordb.operations.community_edge_ops import ( + FalkorCommunityEdgeOperations, +) +from graphiti_core.driver.falkordb.operations.community_node_ops import ( + FalkorCommunityNodeOperations, +) +from graphiti_core.driver.falkordb.operations.entity_edge_ops import FalkorEntityEdgeOperations +from graphiti_core.driver.falkordb.operations.entity_node_ops import FalkorEntityNodeOperations +from graphiti_core.driver.falkordb.operations.episode_node_ops import FalkorEpisodeNodeOperations +from graphiti_core.driver.falkordb.operations.episodic_edge_ops import FalkorEpisodicEdgeOperations +from graphiti_core.driver.falkordb.operations.graph_ops import FalkorGraphMaintenanceOperations +from graphiti_core.driver.falkordb.operations.has_episode_edge_ops import ( + FalkorHasEpisodeEdgeOperations, +) +from graphiti_core.driver.falkordb.operations.next_episode_edge_ops import ( + FalkorNextEpisodeEdgeOperations, +) +from graphiti_core.driver.falkordb.operations.saga_node_ops import FalkorSagaNodeOperations +from graphiti_core.driver.falkordb.operations.search_ops import FalkorSearchOperations +from graphiti_core.driver.operations.community_edge_ops import CommunityEdgeOperations +from graphiti_core.driver.operations.community_node_ops import CommunityNodeOperations +from graphiti_core.driver.operations.entity_edge_ops import EntityEdgeOperations +from graphiti_core.driver.operations.entity_node_ops import EntityNodeOperations +from graphiti_core.driver.operations.episode_node_ops import EpisodeNodeOperations +from graphiti_core.driver.operations.episodic_edge_ops import EpisodicEdgeOperations +from graphiti_core.driver.operations.graph_ops import GraphMaintenanceOperations +from graphiti_core.driver.operations.has_episode_edge_ops import HasEpisodeEdgeOperations +from graphiti_core.driver.operations.next_episode_edge_ops import NextEpisodeEdgeOperations +from graphiti_core.driver.operations.saga_node_ops import SagaNodeOperations +from graphiti_core.driver.operations.search_ops import SearchOperations +from graphiti_core.graph_queries import get_fulltext_indices, get_range_indices, get_vector_indices +from graphiti_core.helpers import validate_group_ids +from graphiti_core.utils.datetime_utils import convert_datetimes_to_strings + +logger = logging.getLogger(__name__) + + +class FalkorDriverSession(GraphDriverSession): + provider = GraphProvider.FALKORDB + + def __init__(self, graph: FalkorGraph): + self.graph = graph + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + # No cleanup needed for Falkor, but method must exist + pass + + async def close(self): + # No explicit close needed for FalkorDB, but method must exist + pass + + async def execute_write(self, func, *args, **kwargs): + # Directly await the provided async function with `self` as the transaction/session + return await func(self, *args, **kwargs) + + async def run(self, query: str | list, **kwargs: Any) -> Any: + # FalkorDB does not support argument for Label Set, so it's converted into an array of queries + if isinstance(query, list): + for cypher, params in query: + params = convert_datetimes_to_strings(params) + await self.graph.query(str(cypher), params) # type: ignore[reportUnknownArgumentType] + else: + params = dict(kwargs) + params = convert_datetimes_to_strings(params) + await self.graph.query(str(query), params) # type: ignore[reportUnknownArgumentType] + # Assuming `graph.query` is async (ideal); otherwise, wrap in executor + return None + + +class FalkorDriver(GraphDriver): + provider = GraphProvider.FALKORDB + default_group_id: str = '\\_' + fulltext_syntax: str = '@' # FalkorDB uses a redisearch-like syntax for fulltext queries + aoss_client: None = None + + def __init__( + self, + host: str = 'localhost', + port: int = 6379, + username: str | None = None, + password: str | None = None, + falkor_db: FalkorDB | None = None, + database: str = 'default_db', + ): + """ + Initialize the FalkorDB driver. + + FalkorDB is a multi-tenant graph database. + To connect, provide the host and port. + The default parameters assume a local (on-premises) FalkorDB instance. + + Args: + host (str): The host where FalkorDB is running. + port (int): The port on which FalkorDB is listening. + username (str | None): The username for authentication (if required). + password (str | None): The password for authentication (if required). + falkor_db (FalkorDB | None): An existing FalkorDB instance to use instead of creating a new one. + database (str): The name of the database to connect to. Defaults to 'default_db'. + """ + super().__init__() + self._database = database + if falkor_db is not None: + # If a FalkorDB instance is provided, use it directly + self.client = falkor_db + else: + self.client = FalkorDB(host=host, port=port, username=username, password=password) + + # Instantiate FalkorDB operations + self._entity_node_ops = FalkorEntityNodeOperations() + self._episode_node_ops = FalkorEpisodeNodeOperations() + self._community_node_ops = FalkorCommunityNodeOperations() + self._saga_node_ops = FalkorSagaNodeOperations() + self._entity_edge_ops = FalkorEntityEdgeOperations() + self._episodic_edge_ops = FalkorEpisodicEdgeOperations() + self._community_edge_ops = FalkorCommunityEdgeOperations() + self._has_episode_edge_ops = FalkorHasEpisodeEdgeOperations() + self._next_episode_edge_ops = FalkorNextEpisodeEdgeOperations() + self._search_ops = FalkorSearchOperations() + self._graph_ops = FalkorGraphMaintenanceOperations() + + # Schedule the indices and constraints to be built + try: + # Try to get the current event loop + loop = asyncio.get_running_loop() + # Schedule the build_indices_and_constraints to run + loop.create_task(self.build_indices_and_constraints()) + except RuntimeError: + # No event loop running, this will be handled later + pass + + # --- Operations properties --- + + @property + def entity_node_ops(self) -> EntityNodeOperations: + return self._entity_node_ops + + @property + def episode_node_ops(self) -> EpisodeNodeOperations: + return self._episode_node_ops + + @property + def community_node_ops(self) -> CommunityNodeOperations: + return self._community_node_ops + + @property + def saga_node_ops(self) -> SagaNodeOperations: + return self._saga_node_ops + + @property + def entity_edge_ops(self) -> EntityEdgeOperations: + return self._entity_edge_ops + + @property + def episodic_edge_ops(self) -> EpisodicEdgeOperations: + return self._episodic_edge_ops + + @property + def community_edge_ops(self) -> CommunityEdgeOperations: + return self._community_edge_ops + + @property + def has_episode_edge_ops(self) -> HasEpisodeEdgeOperations: + return self._has_episode_edge_ops + + @property + def next_episode_edge_ops(self) -> NextEpisodeEdgeOperations: + return self._next_episode_edge_ops + + @property + def search_ops(self) -> SearchOperations: + return self._search_ops + + @property + def graph_ops(self) -> GraphMaintenanceOperations: + return self._graph_ops + + def _get_graph(self, graph_name: str | None) -> FalkorGraph: + # FalkorDB requires a non-None database name for multi-tenant graphs; the default is "default_db" + if graph_name is None: + graph_name = self._database + return self.client.select_graph(graph_name) + + async def execute_query(self, cypher_query_, **kwargs: Any): + graph = self._get_graph(self._database) + + # Convert datetime objects to ISO strings (FalkorDB does not support datetime objects directly) + params = convert_datetimes_to_strings(dict(kwargs)) + + try: + result = await graph.query(cypher_query_, params) # type: ignore[reportUnknownArgumentType] + except Exception as e: + if 'already indexed' in str(e): + # check if index already exists + logger.info(f'Index already exists: {e}') + return None + logger.error(f'Error executing FalkorDB query: {e}\n{cypher_query_}\n{params}') + raise + + # Convert the result header to a list of strings + header = [h[1] for h in result.header] + + # Convert FalkorDB's result format (list of lists) to the format expected by Graphiti (list of dicts) + records = [] + for row in result.result_set: + record = {} + for i, field_name in enumerate(header): + if i < len(row): + record[field_name] = row[i] + else: + # If there are more fields in header than values in row, set to None + record[field_name] = None + records.append(record) + + return records, header, None + + def session(self, database: str | None = None) -> GraphDriverSession: + return FalkorDriverSession(self._get_graph(database)) + + async def close(self) -> None: + """Close the driver connection.""" + if hasattr(self.client, 'aclose'): + await self.client.aclose() # type: ignore[reportUnknownMemberType] + elif hasattr(self.client.connection, 'aclose'): + await self.client.connection.aclose() + elif hasattr(self.client.connection, 'close'): + await self.client.connection.close() + + async def delete_all_indexes(self) -> None: + result = await self.execute_query('CALL db.indexes()') + if not result: + return + + records, _, _ = result + drop_tasks = [] + + for record in records: + label = record['label'] + entity_type = record['entitytype'] + + for field_name, index_type in record['types'].items(): + if 'RANGE' in index_type: + drop_tasks.append(self.execute_query(f'DROP INDEX ON :{label}({field_name})')) + elif 'FULLTEXT' in index_type: + if entity_type == 'NODE': + drop_tasks.append( + self.execute_query( + f'DROP FULLTEXT INDEX FOR (n:{label}) ON (n.{field_name})' + ) + ) + elif entity_type == 'RELATIONSHIP': + drop_tasks.append( + self.execute_query( + f'DROP FULLTEXT INDEX FOR ()-[e:{label}]-() ON (e.{field_name})' + ) + ) + + if drop_tasks: + await asyncio.gather(*drop_tasks) + + async def build_indices_and_constraints(self, delete_existing=False): + if delete_existing: + await self.delete_all_indexes() + # PATCHED 2026-05-02 (BirdAI vendored patch): add vector indexes alongside + # range and fulltext. FalkorDB supports native vector indexes via + # db.idx.vector.queryNodes / queryRelationships; without these, similarity + # search runs as full-table-scan cosine math in interpreted Cypher. + index_queries = ( + get_range_indices(self.provider) + + get_fulltext_indices(self.provider) + + get_vector_indices(self.provider) + ) + for query in index_queries: + await self.execute_query(query) + # Invalidate the search_ops vector-index existence cache so subsequent + # similarity queries re-probe and discover the indexes we just built. + try: + from graphiti_core.driver.falkordb.operations.search_ops import ( + _invalidate_falkordb_vector_index_cache, + ) + _invalidate_falkordb_vector_index_cache() + except ImportError: + # search_ops module not yet imported (cold start); cache is empty + # by default, so no invalidation needed. + pass + + def clone(self, database: str) -> 'GraphDriver': + """ + Returns a shallow copy of this driver with a different default database. + Reuses the same connection (e.g. FalkorDB, Neo4j). + """ + if database == self._database: + cloned = self + elif database == self.default_group_id: + cloned = FalkorDriver(falkor_db=self.client) + else: + # Create a new instance of FalkorDriver with the same connection but a different database + cloned = FalkorDriver(falkor_db=self.client, database=database) + + return cloned + + async def health_check(self) -> None: + """Check FalkorDB connectivity by running a simple query.""" + try: + await self.execute_query('MATCH (n) RETURN 1 LIMIT 1') + return None + except Exception as e: + print(f'FalkorDB health check failed: {e}') + raise + + @staticmethod + def convert_datetimes_to_strings(obj): + if isinstance(obj, dict): + return {k: FalkorDriver.convert_datetimes_to_strings(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [FalkorDriver.convert_datetimes_to_strings(item) for item in obj] + elif isinstance(obj, tuple): + return tuple(FalkorDriver.convert_datetimes_to_strings(item) for item in obj) + elif isinstance(obj, datetime): + return obj.isoformat() + else: + return obj + + def sanitize(self, query: str) -> str: + """ + Replace FalkorDB special characters with whitespace. + Based on FalkorDB tokenization rules: ,.<>{}[]"':;!@#$%^&*()-+=~ + """ + # FalkorDB separator characters that break text into tokens + separator_map = str.maketrans( + { + ',': ' ', + '.': ' ', + '<': ' ', + '>': ' ', + '{': ' ', + '}': ' ', + '[': ' ', + ']': ' ', + '"': ' ', + "'": ' ', + ':': ' ', + ';': ' ', + '!': ' ', + '@': ' ', + '#': ' ', + '$': ' ', + '%': ' ', + '^': ' ', + '&': ' ', + '*': ' ', + '(': ' ', + ')': ' ', + '-': ' ', + '+': ' ', + '=': ' ', + '~': ' ', + '?': ' ', + '|': ' ', + '/': ' ', + '\\': ' ', + } + ) + sanitized = query.translate(separator_map) + # Clean up multiple spaces + sanitized = ' '.join(sanitized.split()) + return sanitized + + def build_fulltext_query( + self, query: str, group_ids: list[str] | None = None, max_query_length: int = 128 + ) -> str: + """ + Build a fulltext query string for FalkorDB using RedisSearch syntax. + FalkorDB uses RedisSearch-like syntax where: + - Field queries use @ prefix: @field:value + - Multiple values for same field: (@field:value1|value2) + - Text search doesn't need @ prefix for content fields + - AND is implicit with space: (@group_id:value) (text) + - OR uses pipe within parentheses: (@group_id:value1|value2) + """ + validate_group_ids(group_ids) + + if group_ids is None or len(group_ids) == 0: + group_filter = '' + else: + # Escape group_ids with quotes to prevent RediSearch syntax errors + # with reserved words like "main" or special characters like hyphens + escaped_group_ids = [f'"{gid}"' for gid in group_ids] + group_values = '|'.join(escaped_group_ids) + group_filter = f'(@group_id:{group_values})' + + sanitized_query = self.sanitize(query) + + # Remove stopwords and empty tokens from the sanitized query + query_words = sanitized_query.split() + filtered_words = [word for word in query_words if word and word.lower() not in STOPWORDS] + sanitized_query = ' | '.join(filtered_words) + + # If the query is too long return no query + if len(sanitized_query.split(' ')) + len(group_ids or '') >= max_query_length: + return '' + + full_query = group_filter + ' (' + sanitized_query + ')' + + return full_query diff --git a/graphiti_patches/graphiti_core/graph_queries.py b/graphiti_patches/graphiti_core/graph_queries.py new file mode 100644 index 0000000..f825a26 --- /dev/null +++ b/graphiti_patches/graphiti_core/graph_queries.py @@ -0,0 +1,242 @@ +""" +Database query utilities for different graph database backends. + +This module provides database-agnostic query generation for Neo4j and FalkorDB, +supporting index creation, fulltext search, and bulk operations. + +PATCHED for FalkorDB native vector index support (BirdAI vendored patch, +2026-05-02). Adds: +- get_vector_indices(): CREATE VECTOR INDEX statements for FalkorDB +- get_vector_search_query(): Cypher fragment for vector similarity using + FalkorDB's db.idx.vector procedures, with fallback to cosine math when + the index does not yet exist +- VECTOR_INDEX_CANDIDATE_MULTIPLIER: over-fetch factor for vector index + queries to handle filter rejections after index lookup + +No changes to Neo4j or Kuzu code paths. +""" + +from typing_extensions import LiteralString + +from graphiti_core.driver.driver import GraphProvider + +# Mapping from Neo4j fulltext index names to FalkorDB node labels +NEO4J_TO_FALKORDB_MAPPING = { + 'node_name_and_summary': 'Entity', + 'community_name': 'Community', + 'episode_content': 'Episodic', + 'edge_name_and_fact': 'RELATES_TO', +} +# Mapping from fulltext index names to Kuzu node labels +INDEX_TO_LABEL_KUZU_MAPPING = { + 'node_name_and_summary': 'Entity', + 'community_name': 'Community', + 'episode_content': 'Episodic', + 'edge_name_and_fact': 'RelatesToNode_', +} + +# Vector index over-fetch multiplier. When a vector index search is +# combined with WHERE filters (group_id, source_uuid, etc.), some of +# the top-k index results may be filtered out. Over-fetching by this +# factor preserves recall against the final LIMIT after filtering. +# Conservative default; tunable per-deployment by editing this constant +# or via environment-variable override at the driver level (future). +VECTOR_INDEX_CANDIDATE_MULTIPLIER = 5 + + +def get_range_indices(provider: GraphProvider) -> list[LiteralString]: + if provider == GraphProvider.FALKORDB: + return [ + # Entity node + 'CREATE INDEX FOR (n:Entity) ON (n.uuid, n.group_id, n.name, n.created_at)', + # Episodic node + 'CREATE INDEX FOR (n:Episodic) ON (n.uuid, n.group_id, n.created_at, n.valid_at)', + # Community node + 'CREATE INDEX FOR (n:Community) ON (n.uuid)', + # Saga node + 'CREATE INDEX FOR (n:Saga) ON (n.uuid, n.group_id, n.name)', + # RELATES_TO edge + 'CREATE INDEX FOR ()-[e:RELATES_TO]-() ON (e.uuid, e.group_id, e.name, e.created_at, e.expired_at, e.valid_at, e.invalid_at)', + # MENTIONS edge + 'CREATE INDEX FOR ()-[e:MENTIONS]-() ON (e.uuid, e.group_id)', + # HAS_MEMBER edge + 'CREATE INDEX FOR ()-[e:HAS_MEMBER]-() ON (e.uuid)', + # HAS_EPISODE edge + 'CREATE INDEX FOR ()-[e:HAS_EPISODE]-() ON (e.uuid, e.group_id)', + # NEXT_EPISODE edge + 'CREATE INDEX FOR ()-[e:NEXT_EPISODE]-() ON (e.uuid, e.group_id)', + ] + + if provider == GraphProvider.KUZU: + return [] + + return [ + 'CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)', + 'CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)', + 'CREATE INDEX community_uuid IF NOT EXISTS FOR (n:Community) ON (n.uuid)', + 'CREATE INDEX saga_uuid IF NOT EXISTS FOR (n:Saga) ON (n.uuid)', + 'CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.uuid)', + 'CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.uuid)', + 'CREATE INDEX has_member_uuid IF NOT EXISTS FOR ()-[e:HAS_MEMBER]-() ON (e.uuid)', + 'CREATE INDEX has_episode_uuid IF NOT EXISTS FOR ()-[e:HAS_EPISODE]-() ON (e.uuid)', + 'CREATE INDEX next_episode_uuid IF NOT EXISTS FOR ()-[e:NEXT_EPISODE]-() ON (e.uuid)', + 'CREATE INDEX entity_group_id IF NOT EXISTS FOR (n:Entity) ON (n.group_id)', + 'CREATE INDEX episode_group_id IF NOT EXISTS FOR (n:Episodic) ON (n.group_id)', + 'CREATE INDEX community_group_id IF NOT EXISTS FOR (n:Community) ON (n.group_id)', + 'CREATE INDEX saga_group_id IF NOT EXISTS FOR (n:Saga) ON (n.group_id)', + 'CREATE INDEX relation_group_id IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.group_id)', + 'CREATE INDEX mention_group_id IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.group_id)', + 'CREATE INDEX has_episode_group_id IF NOT EXISTS FOR ()-[e:HAS_EPISODE]-() ON (e.group_id)', + 'CREATE INDEX next_episode_group_id IF NOT EXISTS FOR ()-[e:NEXT_EPISODE]-() ON (e.group_id)', + 'CREATE INDEX name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)', + 'CREATE INDEX saga_name IF NOT EXISTS FOR (n:Saga) ON (n.name)', + 'CREATE INDEX created_at_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.created_at)', + 'CREATE INDEX created_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.created_at)', + 'CREATE INDEX valid_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.valid_at)', + 'CREATE INDEX name_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.name)', + 'CREATE INDEX created_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.created_at)', + 'CREATE INDEX expired_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.expired_at)', + 'CREATE INDEX valid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.valid_at)', + 'CREATE INDEX invalid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.invalid_at)', + ] + + +def get_fulltext_indices(provider: GraphProvider) -> list[LiteralString]: + if provider == GraphProvider.FALKORDB: + from typing import cast + + from graphiti_core.driver.falkordb import STOPWORDS + + # Convert to string representation for embedding in queries + stopwords_str = str(STOPWORDS) + + # Use type: ignore to satisfy LiteralString requirement while maintaining single source of truth + return cast( + list[LiteralString], + [ + f"""CALL db.idx.fulltext.createNodeIndex( + {{ + label: 'Episodic', + stopwords: {stopwords_str} + }}, + 'content', 'source', 'source_description', 'group_id' + )""", + f"""CALL db.idx.fulltext.createNodeIndex( + {{ + label: 'Entity', + stopwords: {stopwords_str} + }}, + 'name', 'summary', 'group_id' + )""", + f"""CALL db.idx.fulltext.createNodeIndex( + {{ + label: 'Community', + stopwords: {stopwords_str} + }}, + 'name', 'group_id' + )""", + """CREATE FULLTEXT INDEX FOR ()-[e:RELATES_TO]-() ON (e.name, e.fact, e.group_id)""", + ], + ) + + if provider == GraphProvider.KUZU: + return [ + "CALL CREATE_FTS_INDEX('Episodic', 'episode_content', ['content', 'source', 'source_description']);", + "CALL CREATE_FTS_INDEX('Entity', 'node_name_and_summary', ['name', 'summary']);", + "CALL CREATE_FTS_INDEX('Community', 'community_name', ['name']);", + "CALL CREATE_FTS_INDEX('RelatesToNode_', 'edge_name_and_fact', ['name', 'fact']);", + ] + + return [ + """CREATE FULLTEXT INDEX episode_content IF NOT EXISTS + FOR (e:Episodic) ON EACH [e.content, e.source, e.source_description, e.group_id]""", + """CREATE FULLTEXT INDEX node_name_and_summary IF NOT EXISTS + FOR (n:Entity) ON EACH [n.name, n.summary, n.group_id]""", + """CREATE FULLTEXT INDEX community_name IF NOT EXISTS + FOR (n:Community) ON EACH [n.name, n.group_id]""", + """CREATE FULLTEXT INDEX edge_name_and_fact IF NOT EXISTS + FOR ()-[e:RELATES_TO]-() ON EACH [e.name, e.fact, e.group_id]""", + ] + + +def get_vector_indices(provider: GraphProvider, dimension: int = 384) -> list[LiteralString]: + """Return CREATE VECTOR INDEX statements for the given provider. + + For FalkorDB: creates HNSW vector indexes on Entity.name_embedding, + RELATES_TO.fact_embedding, and Community.name_embedding. Backed by + FalkorDB's native vector index (db.idx.vector.queryNodes / + queryRelationships). + + For Neo4j and Kuzu: returns an empty list. Those backends create vector + indexes via different mechanisms (Neo4j auto-creates them when needed + via its vector.similarity.cosine function; Kuzu uses array_cosine_similarity + and does not require pre-built vector indexes for graphiti-core's usage). + + Args: + provider: The graph database provider. + dimension: Embedding dimension. Defaults to 384 (all-MiniLM-L6-v2). + Embedders with different dimensions should pass their own value + through driver configuration. graphiti-core's default embedder + is 1536 (OpenAI ada-002); BirdAI uses 384 (sentence-transformers). + + Returns: + List of CREATE VECTOR INDEX statements. Idempotent at FalkorDB level + if the index already exists with matching options. + """ + if provider == GraphProvider.FALKORDB: + from typing import cast + return cast( + list[LiteralString], + [ + f"CREATE VECTOR INDEX FOR (n:Entity) ON (n.name_embedding) " + f"OPTIONS {{dimension: {dimension}, similarityFunction: 'cosine'}}", + f"CREATE VECTOR INDEX FOR ()-[e:RELATES_TO]-() ON (e.fact_embedding) " + f"OPTIONS {{dimension: {dimension}, similarityFunction: 'cosine'}}", + f"CREATE VECTOR INDEX FOR (n:Community) ON (n.name_embedding) " + f"OPTIONS {{dimension: {dimension}, similarityFunction: 'cosine'}}", + ], + ) + + return [] + + +def get_nodes_query(name: str, query: str, limit: int, provider: GraphProvider) -> str: + if provider == GraphProvider.FALKORDB: + label = NEO4J_TO_FALKORDB_MAPPING[name] + return f"CALL db.idx.fulltext.queryNodes('{label}', {query})" + + if provider == GraphProvider.KUZU: + label = INDEX_TO_LABEL_KUZU_MAPPING[name] + return f"CALL QUERY_FTS_INDEX('{label}', '{name}', {query}, TOP := $limit)" + + return f'CALL db.index.fulltext.queryNodes("{name}", {query}, {{limit: $limit}})' + + +def get_vector_cosine_func_query(vec1, vec2, provider: GraphProvider) -> str: + """Return a Cypher fragment for cosine similarity score in [0, 1]. + + PRESERVED for backward compatibility and as fallback when vector indexes + do not yet exist on the FalkorDB backend. New code paths should prefer + get_vector_search_query() which uses the native vector index when + available. + """ + if provider == GraphProvider.FALKORDB: + # FalkorDB uses a different syntax for regular cosine similarity and Neo4j uses normalized cosine similarity + return f'(2 - vec.cosineDistance({vec1}, vecf32({vec2})))/2' + + if provider == GraphProvider.KUZU: + return f'array_cosine_similarity({vec1}, {vec2})' + + return f'vector.similarity.cosine({vec1}, {vec2})' + + +def get_relationships_query(name: str, limit: int, provider: GraphProvider) -> str: + if provider == GraphProvider.FALKORDB: + label = NEO4J_TO_FALKORDB_MAPPING[name] + return f"CALL db.idx.fulltext.queryRelationships('{label}', $query)" + + if provider == GraphProvider.KUZU: + label = INDEX_TO_LABEL_KUZU_MAPPING[name] + return f"CALL QUERY_FTS_INDEX('{label}', '{name}', cast($query AS STRING), TOP := $limit)" + + return f'CALL db.index.fulltext.queryRelationships("{name}", $query, {{limit: $limit}})'